提交 c8911895 编写于 作者: C chengduoZH

update sparse gradient parameter with reduce and broadcast

上级 5ff1ef36
...@@ -37,25 +37,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -37,25 +37,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale, platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
bool use_nccl_allreduce)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes), local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs), nccl_ctxs_(nccl_ctxs) {
use_nccl_allreduce_(use_nccl_allreduce) {
#else #else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale, const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
bool use_nccl_allreduce)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes), local_scopes_(local_scopes) {
use_nccl_allreduce_(use_nccl_allreduce) {
#endif #endif
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
...@@ -121,8 +116,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -121,8 +116,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
size_t cur_device_id = 0; // size_t cur_device_id = 0;
size_t update_sparse_gp_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
...@@ -162,14 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -162,14 +157,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (use_nccl_allreduce_) { if (IsSparseGradient(og)) {
InsertNCCLAllReduceOp(&result, og); CreateReduceOp(&result, update_sparse_gp_device_id, og);
} else { var_name_on_devices[update_sparse_gp_device_id].emplace(og);
CreateReduceOp(&result, cur_device_id, og); bcast_var_name_set[update_sparse_gp_device_id].emplace(
var_name_on_devices[cur_device_id].emplace(og);
bcast_var_name_set[cur_device_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix))); og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_device_id = (cur_device_id + 1) % places_.size(); } else {
InsertNCCLAllReduceOp(&result, og);
} }
} }
} }
...@@ -205,13 +199,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -205,13 +199,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
auto og_var = local_scopes_[0]->FindVar(og);
PADDLE_ENFORCE_NOT_NULL(og_var);
return og_var->IsType<SelectedRows>();
}
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices, const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const { const OpDesc &op) const {
if (use_nccl_allreduce_) {
return -1;
}
int var_dev_id = -1; int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) { for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break; if (var_dev_id != -1) break;
......
...@@ -36,13 +36,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -36,13 +36,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, platform::NCCLContextMap *nccl_ctxs,
bool use_default_grad_scale, bool use_nccl_allreduce); bool use_default_grad_scale);
#else #else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places, MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
bool use_default_grad_scale, bool use_nccl_allreduce); bool use_default_grad_scale);
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
...@@ -60,7 +60,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -60,7 +60,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool use_nccl_allreduce_;
bool use_default_grad_scale_; bool use_default_grad_scale_;
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
...@@ -99,6 +98,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -99,6 +98,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
* nullptr if not found. * nullptr if not found.
*/ */
OpDesc *GetSendOpDesc(const ProgramDesc &program) const; OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient(const std::string &og) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay, Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool use_default_grad_scale, bool use_nccl_allreduce) bool use_default_grad_scale)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
...@@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor( ...@@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), use_default_grad_scale, use_nccl_allreduce); member_->nccl_ctxs_.get(), use_default_grad_scale);
#else #else
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
member_->places_, loss_var_name, params, member_->local_scopes_, params, member_->local_scopes_,
use_default_grad_scale, use_nccl_allreduce); use_default_grad_scale);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
......
...@@ -40,8 +40,7 @@ class ParallelExecutor { ...@@ -40,8 +40,7 @@ class ParallelExecutor {
const ProgramDesc& main_program, const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay, bool use_default_grad_scale, bool allow_op_delay, bool use_default_grad_scale);
bool use_nccl_allreduce);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -502,12 +502,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -502,12 +502,11 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector<Scope *> &local_scopes, Scope *scope, std::vector<Scope *> &local_scopes,
bool allow_op_delay, bool use_default_grad_scale, bool allow_op_delay, bool use_default_grad_scale) {
bool use_nccl_allreduce) {
new (&self) ParallelExecutor( new (&self) ParallelExecutor(
num_threads, use_event, places, params, bcast_vars, num_threads, use_event, places, params, bcast_vars,
main_program, loss_var_name, scope, local_scopes, main_program, loss_var_name, scope, local_scopes,
allow_op_delay, use_default_grad_scale, use_nccl_allreduce); allow_op_delay, use_default_grad_scale);
}) })
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
......
...@@ -30,8 +30,7 @@ class ParallelExecutor(object): ...@@ -30,8 +30,7 @@ class ParallelExecutor(object):
num_threads=None, num_threads=None,
allow_op_delay=False, allow_op_delay=False,
share_vars_from=None, share_vars_from=None,
use_default_grad_scale=True, use_default_grad_scale=True):
use_nccl_allreduce=True):
""" """
ParallelExecutor can run program in parallel. ParallelExecutor can run program in parallel.
...@@ -47,14 +46,6 @@ class ParallelExecutor(object): ...@@ -47,14 +46,6 @@ class ParallelExecutor(object):
improve performance in some cases, default False. improve performance in some cases, default False.
share_vars_from(ParallelExecutor, default None): If provied, share_vars_from(ParallelExecutor, default None): If provied,
it will share variables from the specified ParallelExecutor. it will share variables from the specified ParallelExecutor.
use_nccl_allreduce(bool, default True): Whether to use nccl_allreduce
or not, if set True, the communication between different
devices by nccl allReduce, which doesn't support updating sparse
parameter, if set False, the communication between different
devices by reduce_op and broadcast_op, which will distribute all
the parameter gradients evenly to different device and updates
the parameters, and finally broadcast to other device, this method
support updating sparse parameter. Default True.
use_default_grad_scale(bool, default True): If set True, a default use_default_grad_scale(bool, default True): If set True, a default
scale value equal to `1./device_count` would be multiplied to scale value equal to `1./device_count` would be multiplied to
gradients of each device and scaled gradients would be gradients of each device and scaled gradients would be
...@@ -138,8 +129,7 @@ class ParallelExecutor(object): ...@@ -138,8 +129,7 @@ class ParallelExecutor(object):
scope, scope,
local_scopes, local_scopes,
allow_op_delay, allow_op_delay,
use_default_grad_scale, use_default_grad_scale)
use_nccl_allreduce)
self.scope = scope self.scope = scope
......
...@@ -205,8 +205,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -205,8 +205,7 @@ class TestParallelExecutorBase(unittest.TestCase):
allow_op_delay=False, allow_op_delay=False,
feed_dict=None, feed_dict=None,
seed=None, seed=None,
use_parallel_executor=True, use_parallel_executor=True):
use_nccl_allreduce=True):
def run_executor(exe, feed, fetch_list, program=None): def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor): if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed) res = exe.run(fetch_list=fetch_list, feed=feed)
...@@ -235,10 +234,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -235,10 +234,7 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, True, loss_name=loss.name, allow_op_delay=allow_op_delay)
loss_name=loss.name,
allow_op_delay=allow_op_delay,
use_nccl_allreduce=use_nccl_allreduce)
else: else:
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
...@@ -284,25 +280,20 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -284,25 +280,20 @@ class TestMNIST(TestParallelExecutorBase):
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) './mnist.recordio', reader, feeder)
def check_simple_fc_convergence(self, use_nccl_allreduce=True): def check_simple_fc_convergence(self):
self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True) self.check_network_convergence(simple_fc_net, allow_op_delay=True)
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64') label = numpy.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, simple_fc_net, feed_dict={"image": img,
feed_dict={"image": img, "label": label})
"label": label},
use_nccl_allreduce=use_nccl_allreduce)
def test_simple_fc_with_nccl_allreduce(self): def test_simple_fc(self):
self.check_simple_fc_convergence(True) self.check_simple_fc_convergence()
def test_simple_fc_with_reduce_op(self): def check_simple_fc_parallel_accuracy(self):
self.check_simple_fc_convergence(False)
def check_simple_fc_parallel_accuracy(self, use_nccl_allreduce=True):
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64') label = numpy.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
...@@ -316,35 +307,26 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -316,35 +307,26 @@ class TestMNIST(TestParallelExecutorBase):
seed=1000, seed=1000,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_parallel_executor=True, use_parallel_executor=True)
use_nccl_allreduce=use_nccl_allreduce)
for p_f in parallel_first_loss: for p_f in parallel_first_loss:
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6) self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
for p_l in parallel_last_loss: for p_l in parallel_last_loss:
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6) self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_simple_fc_parallel_accuracy_with_nccl_allreduce(self): def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy(True) self.check_simple_fc_parallel_accuracy()
def test_simple_fc_parallel_accuracy_with_reduce_op(self):
self.check_simple_fc_parallel_accuracy(False)
def check_batchnorm_fc_convergence(self, use_nccl_allreduce): def check_batchnorm_fc_convergence(self):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm)
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64') label = numpy.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
fc_with_batchnorm, fc_with_batchnorm, feed_dict={"image": img,
feed_dict={"image": img, "label": label})
"label": label},
use_nccl_allreduce=use_nccl_allreduce)
def test_batchnorm_fc_with_nccl_allreduce(self):
self.check_batchnorm_fc_convergence(True)
def test_batchnorm_fc_with_reduce_op(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(False) self.check_batchnorm_fc_convergence()
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
...@@ -366,21 +348,17 @@ class TestResnet(TestParallelExecutorBase): ...@@ -366,21 +348,17 @@ class TestResnet(TestParallelExecutorBase):
# fluid.recordio_writer.convert_reader_to_recordio_file( # fluid.recordio_writer.convert_reader_to_recordio_file(
# "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress) # "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)
def check_resnet_convergence(self, use_nccl_allreduce): def check_resnet_convergence(self):
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
functools.partial( functools.partial(
SE_ResNeXt50Small, batch_size=batch_size), SE_ResNeXt50Small, batch_size=batch_size),
iter=20, iter=20,
batch_size=batch_size, batch_size=batch_size)
use_nccl_allreduce=use_nccl_allreduce)
def test_resnet_with_nccl_allreduce(self): def test_resnet(self):
self.check_resnet_convergence(True) self.check_resnet_convergence()
def test_resnet_with_reduce_op(self):
self.check_resnet_convergence(False)
class ModelHyperParams(object): class ModelHyperParams(object):
...@@ -544,7 +522,7 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -544,7 +522,7 @@ class TestTransformer(TestParallelExecutorBase):
class ParallelExecutorTestingDuringTraining(unittest.TestCase): class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def check_network_convergence(self, use_nccl_allreduce): def check_network_convergence(self):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -565,16 +543,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -565,16 +543,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=True, loss_name=loss.name, main_program=main)
loss_name=loss.name,
main_program=main,
use_nccl_allreduce=use_nccl_allreduce)
test_exe = fluid.ParallelExecutor( test_exe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=True,
main_program=test_program, main_program=test_program,
share_vars_from=train_exe, share_vars_from=train_exe)
use_nccl_allreduce=use_nccl_allreduce)
for i in xrange(5): for i in xrange(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict) test_loss, = test_exe.run([loss.name], feed=feed_dict)
...@@ -588,11 +562,8 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -588,11 +562,8 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
"Train loss: " + str(train_loss) + "\n Test loss:" + "Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss)) str(test_loss))
def test_parallel_testing_with_nccl_allreduce(self): def test_parallel(self):
self.check_network_convergence(use_nccl_allreduce=True) self.check_network_convergence()
def test_parallel_testing_with_reduce_op(self):
self.check_network_convergence(use_nccl_allreduce=False)
import paddle.dataset.conll05 as conll05 import paddle.dataset.conll05 as conll05
...@@ -612,7 +583,7 @@ embedding_name = 'emb' ...@@ -612,7 +583,7 @@ embedding_name = 'emb'
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
is_sparse, use_nccl_allreduce, **ignored): is_sparse, **ignored):
# 8 features # 8 features
predicate_embedding = fluid.layers.embedding( predicate_embedding = fluid.layers.embedding(
input=predicate, input=predicate,
...@@ -681,7 +652,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, ...@@ -681,7 +652,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
class TestCRFModel(unittest.TestCase): class TestCRFModel(unittest.TestCase):
def check_network_convergence(self, is_sparse, use_nccl_allreduce): def check_network_convergence(self, is_sparse):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -729,10 +700,7 @@ class TestCRFModel(unittest.TestCase): ...@@ -729,10 +700,7 @@ class TestCRFModel(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor( pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
use_cuda=True,
loss_name=avg_cost.name,
use_nccl_allreduce=use_nccl_allreduce)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[ feed_list=[
...@@ -749,11 +717,7 @@ class TestCRFModel(unittest.TestCase): ...@@ -749,11 +717,7 @@ class TestCRFModel(unittest.TestCase):
fetch_list=[avg_cost.name]))[0] fetch_list=[avg_cost.name]))[0]
def test_update_sparse_parameter(self): def test_update_sparse_parameter(self):
self.check_network_convergence(is_sparse=True, use_nccl_allreduce=False) self.check_network_convergence(is_sparse=True)
def test_update_dense_parameter_with_nccl_allreduce(self):
self.check_network_convergence(is_sparse=False, use_nccl_allreduce=True)
def test_update_dense_parameter_with_reduce_op(self): def test_update_dense_parameter(self):
self.check_network_convergence( self.check_network_convergence(is_sparse=False)
is_sparse=False, use_nccl_allreduce=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册