提交 dd45f529 编写于 作者: W whs 提交者: qingqing01

Fix evaluator in parallel mode (#923)

上级 dc2cc7d2
...@@ -71,6 +71,7 @@ def train(args, data_reader=ctc_reader): ...@@ -71,6 +71,7 @@ def train(args, data_reader=ctc_reader):
print "Init model from: %s." % args.init_model print "Init model from: %s." % args.init_model
train_exe = exe train_exe = exe
error_evaluator.reset(exe)
if args.parallel: if args.parallel:
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=sum_cost.name) use_cuda=True, loss_name=sum_cost.name)
...@@ -81,7 +82,7 @@ def train(args, data_reader=ctc_reader): ...@@ -81,7 +82,7 @@ def train(args, data_reader=ctc_reader):
var_names = [var.name for var in fetch_vars] var_names = [var.name for var in fetch_vars]
if args.parallel: if args.parallel:
results = train_exe.run(var_names, results = train_exe.run(var_names,
feed_dict=get_feeder_data(data, place)) feed=get_feeder_data(data, place))
results = [np.array(result).sum() for result in results] results = [np.array(result).sum() for result in results]
else: else:
results = exe.run(feed=get_feeder_data(data, place), results = exe.run(feed=get_feeder_data(data, place),
...@@ -103,7 +104,6 @@ def train(args, data_reader=ctc_reader): ...@@ -103,7 +104,6 @@ def train(args, data_reader=ctc_reader):
exe, dirname=args.save_model_dir, filename=filename) exe, dirname=args.save_model_dir, filename=filename)
print "Saved model to: %s/%s." % (args.save_model_dir, filename) print "Saved model to: %s/%s." % (args.save_model_dir, filename)
error_evaluator.reset(exe)
for pass_id in range(args.pass_num): for pass_id in range(args.pass_num):
batch_id = 1 batch_id = 1
total_loss = 0.0 total_loss = 0.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册