未验证 提交 9923be5d 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10546 from chengduoZH/feature/change_pe_strategy

Balance parameter_opt between cards
...@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale) platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes), local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) { nccl_ctxs_(nccl_ctxs),
balance_parameter_opt_between_cards_(
balance_parameter_opt_between_cards) {
#else #else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale) const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes) { local_scopes_(local_scopes),
balance_parameter_opt_between_cards_(
balance_parameter_opt_between_cards) {
#endif #endif
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
...@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// Find "send" op first for split is in front of send. // Find "send" op first for split is in front of send.
OpDesc *send_op = GetSendOpDesc(program); OpDesc *send_op = GetSendOpDesc(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") { if (op->Type() == "send") {
...@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
CreateComputationalOps(&result, *op, places_.size()); int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name);
}
}
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(var_types, og)) { if (balance_parameter_opt_between_cards_) {
CreateReduceOp(&result, og, 0); CreateReduceOp(&result, og, cur_device_id);
CreateBroadcastOp(&result, og, 0); var_name_on_devices[cur_device_id].emplace(og);
bcast_var_name_set[cur_device_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_device_id = (cur_device_id + 1) % places_.size();
} else { } else {
InsertNCCLAllReduceOp(&result, og); if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, og, 0);
CreateBroadcastOp(&result, og, 0);
} else {
InsertNCCLAllReduceOp(&result, og);
}
} }
} }
} }
...@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
} }
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
harzaeds need to be handled. harzaeds need to be handled.
...@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (!balance_parameter_opt_between_cards_) {
return -1;
}
int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
}
}
return var_dev_id;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
......
...@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, platform::NCCLContextMap *nccl_ctxs,
bool use_default_grad_scale); bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
#else #else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places, MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
bool use_default_grad_scale); bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
...@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool balance_parameter_opt_between_cards_;
bool use_default_grad_scale_; bool use_default_grad_scale_;
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
...@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
......
...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay, Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool use_default_grad_scale) bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
...@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), use_default_grad_scale); member_->nccl_ctxs_.get(), use_default_grad_scale,
balance_parameter_opt_between_cards);
#else #else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, details::MultiDevSSAGraphBuilder builder(
params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
use_default_grad_scale); use_default_grad_scale, balance_parameter_opt_between_cards);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
......
...@@ -40,7 +40,8 @@ class ParallelExecutor { ...@@ -40,7 +40,8 @@ class ParallelExecutor {
const ProgramDesc& main_program, const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay, bool use_default_grad_scale); bool allow_op_delay, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -502,11 +502,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -502,11 +502,13 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector<Scope *> &local_scopes, Scope *scope, std::vector<Scope *> &local_scopes,
bool allow_op_delay, bool use_default_grad_scale) { bool allow_op_delay, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards) {
new (&self) ParallelExecutor( new (&self) ParallelExecutor(
num_threads, use_event, places, params, bcast_vars, num_threads, use_event, places, params, bcast_vars,
main_program, loss_var_name, scope, local_scopes, main_program, loss_var_name, scope, local_scopes,
allow_op_delay, use_default_grad_scale); allow_op_delay, use_default_grad_scale,
balance_parameter_opt_between_cards);
}) })
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
......
...@@ -30,7 +30,8 @@ class ParallelExecutor(object): ...@@ -30,7 +30,8 @@ class ParallelExecutor(object):
num_threads=None, num_threads=None,
allow_op_delay=False, allow_op_delay=False,
share_vars_from=None, share_vars_from=None,
use_default_grad_scale=True): use_default_grad_scale=True,
balance_parameter_opt_between_cards=False):
""" """
ParallelExecutor can run program in parallel. ParallelExecutor can run program in parallel.
...@@ -51,6 +52,9 @@ class ParallelExecutor(object): ...@@ -51,6 +52,9 @@ class ParallelExecutor(object):
gradients of each device and scaled gradients would be gradients of each device and scaled gradients would be
aggregated. Otherwise, a customized scale value should be fed aggregated. Otherwise, a customized scale value should be fed
to the network. to the network.
balance_parameter_opt_between_cards(bool, default True): Whether
updating different gradients on different cards. Currently, it
is not recommended.
Returns: Returns:
A ParallelExecutor object. A ParallelExecutor object.
...@@ -129,7 +133,8 @@ class ParallelExecutor(object): ...@@ -129,7 +133,8 @@ class ParallelExecutor(object):
scope, scope,
local_scopes, local_scopes,
allow_op_delay, allow_op_delay,
use_default_grad_scale) use_default_grad_scale,
balance_parameter_opt_between_cards)
self.scope = scope self.scope = scope
......
...@@ -205,7 +205,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -205,7 +205,8 @@ class TestParallelExecutorBase(unittest.TestCase):
allow_op_delay=False, allow_op_delay=False,
feed_dict=None, feed_dict=None,
seed=None, seed=None,
use_parallel_executor=True): use_parallel_executor=True,
balance_parameter_opt_between_cards=False):
def run_executor(exe, feed, fetch_list, program=None): def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor): if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed) res = exe.run(fetch_list=fetch_list, feed=feed)
...@@ -234,7 +235,11 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -234,7 +235,11 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, loss_name=loss.name, allow_op_delay=allow_op_delay) True,
loss_name=loss.name,
allow_op_delay=allow_op_delay,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
else: else:
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
...@@ -280,20 +285,27 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -280,20 +285,27 @@ class TestMNIST(TestParallelExecutorBase):
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) './mnist.recordio', reader, feeder)
def check_simple_fc_convergence(self): def check_simple_fc_convergence(self, balance_parameter_opt_between_cards):
self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True) self.check_network_convergence(simple_fc_net, allow_op_delay=True)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, feed_dict={"image": img, simple_fc_net,
"label": label}) feed_dict={"image": img,
"label": label},
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
def test_simple_fc(self): def test_simple_fc(self):
self.check_simple_fc_convergence() self.check_simple_fc_convergence(False)
def test_simple_fc_with_new_strategy(self):
self.check_simple_fc_convergence(True)
def check_simple_fc_parallel_accuracy(self): def check_simple_fc_parallel_accuracy(self,
balance_parameter_opt_between_cards):
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
...@@ -307,7 +319,9 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -307,7 +319,9 @@ class TestMNIST(TestParallelExecutorBase):
seed=1000, seed=1000,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_parallel_executor=True) use_parallel_executor=True,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
for p_f in parallel_first_loss: for p_f in parallel_first_loss:
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6) self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
...@@ -315,18 +329,28 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -315,18 +329,28 @@ class TestMNIST(TestParallelExecutorBase):
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6) self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_simple_fc_parallel_accuracy(self): def test_simple_fc_parallel_accuracy(self):
self.check_simple_fc_parallel_accuracy() self.check_simple_fc_parallel_accuracy(False)
def check_batchnorm_fc_convergence(self): def test_simple_fc_parallel_accuracy_with_new_strategy(self):
self.check_simple_fc_parallel_accuracy(True)
def check_batchnorm_fc_convergence(self,
balance_parameter_opt_between_cards):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm)
img = np.zeros(shape=[32, 784], dtype='float32') img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
fc_with_batchnorm, feed_dict={"image": img, fc_with_batchnorm,
"label": label}) feed_dict={"image": img,
"label": label},
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence() self.check_batchnorm_fc_convergence(False)
def test_batchnorm_fc_with_new_strategy(self):
self.check_batchnorm_fc_convergence(True)
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
...@@ -348,17 +372,22 @@ class TestResnet(TestParallelExecutorBase): ...@@ -348,17 +372,22 @@ class TestResnet(TestParallelExecutorBase):
# fluid.recordio_writer.convert_reader_to_recordio_file( # fluid.recordio_writer.convert_reader_to_recordio_file(
# "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress) # "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)
def check_resnet_convergence(self): def check_resnet_convergence(self, balance_parameter_opt_between_cards):
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
functools.partial( functools.partial(
SE_ResNeXt50Small, batch_size=batch_size), SE_ResNeXt50Small, batch_size=batch_size),
iter=20, iter=20,
batch_size=batch_size) batch_size=batch_size,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
def test_resnet(self): def test_resnet(self):
self.check_resnet_convergence() self.check_resnet_convergence(False)
def test_resnet_with_new_strategy(self):
self.check_resnet_convergence(True)
class ModelHyperParams(object): class ModelHyperParams(object):
...@@ -519,7 +548,7 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -519,7 +548,7 @@ class TestTransformer(TestParallelExecutorBase):
class ParallelExecutorTestingDuringTraining(unittest.TestCase): class ParallelExecutorTestingDuringTraining(unittest.TestCase):
def check_network_convergence(self): def check_network_convergence(self, balance_parameter_opt_between_cards):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -539,12 +568,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -539,12 +568,18 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main) use_cuda=True,
loss_name=loss.name,
main_program=main,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
test_exe = fluid.ParallelExecutor( test_exe = fluid.ParallelExecutor(
use_cuda=True, use_cuda=True,
main_program=test_program, main_program=test_program,
share_vars_from=train_exe) share_vars_from=train_exe,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
for i in xrange(5): for i in xrange(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict) test_loss, = test_exe.run([loss.name], feed=feed_dict)
...@@ -558,8 +593,11 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -558,8 +593,11 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
"Train loss: " + str(train_loss) + "\n Test loss:" + "Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss)) str(test_loss))
def test_parallel(self): def test_parallel_testing(self):
self.check_network_convergence() self.check_network_convergence(False)
def test_parallel_testing_with_new_strategy(self):
self.check_network_convergence(True)
import paddle.dataset.conll05 as conll05 import paddle.dataset.conll05 as conll05
...@@ -579,7 +617,7 @@ embedding_name = 'emb' ...@@ -579,7 +617,7 @@ embedding_name = 'emb'
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
is_sparse, **ignored): is_sparse, balance_parameter_opt_between_cards, **ignored):
# 8 features # 8 features
predicate_embedding = fluid.layers.embedding( predicate_embedding = fluid.layers.embedding(
input=predicate, input=predicate,
...@@ -648,7 +686,9 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, ...@@ -648,7 +686,9 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
class TestCRFModel(unittest.TestCase): class TestCRFModel(unittest.TestCase):
def check_network_convergence(self, is_sparse): def check_network_convergence(self,
is_sparse,
balance_parameter_opt_between_cards=False):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -696,7 +736,11 @@ class TestCRFModel(unittest.TestCase): ...@@ -696,7 +736,11 @@ class TestCRFModel(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) pe = fluid.ParallelExecutor(
use_cuda=True,
loss_name=avg_cost.name,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[ feed_list=[
...@@ -718,6 +762,14 @@ class TestCRFModel(unittest.TestCase): ...@@ -718,6 +762,14 @@ class TestCRFModel(unittest.TestCase):
def test_update_dense_parameter(self): def test_update_dense_parameter(self):
self.check_network_convergence(is_sparse=False) self.check_network_convergence(is_sparse=False)
def test_update_sparse_parameter_with_new_strategy(self):
self.check_network_convergence(
is_sparse=False, balance_parameter_opt_between_cards=True)
def test_update_dense_parameter_with_new_strategy(self):
self.check_network_convergence(
is_sparse=False, balance_parameter_opt_between_cards=True)
# test fetch all the variables of global_block # test fetch all the variables of global_block
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册