提交 dd45f529 编写于 作者: W whs 提交者: qingqing01

Fix evaluator in parallel mode (#923)

上级 dc2cc7d2
......@@ -71,6 +71,7 @@ def train(args, data_reader=ctc_reader):
print "Init model from: %s." % args.init_model
train_exe = exe
error_evaluator.reset(exe)
if args.parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=True, loss_name=sum_cost.name)
......@@ -81,7 +82,7 @@ def train(args, data_reader=ctc_reader):
var_names = [var.name for var in fetch_vars]
if args.parallel:
results = train_exe.run(var_names,
feed_dict=get_feeder_data(data, place))
feed=get_feeder_data(data, place))
results = [np.array(result).sum() for result in results]
else:
results = exe.run(feed=get_feeder_data(data, place),
......@@ -103,7 +104,6 @@ def train(args, data_reader=ctc_reader):
exe, dirname=args.save_model_dir, filename=filename)
print "Saved model to: %s/%s." % (args.save_model_dir, filename)
error_evaluator.reset(exe)
for pass_id in range(args.pass_num):
batch_id = 1
total_loss = 0.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册