提交 a1cfd6d8 编写于 作者: C Chen Weihang

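Fix train reader usage error: feed each batch from train_pyreader() directly to train_exe.run instead of accumulating DEV_COUNT batches first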

上级 57ce8e2f
......@@ -144,33 +144,29 @@ def train(args, train_exe, compiled_prog, build_res, place):
time_begin = time.time()
test_exe = train_exe
logger.info("Begin training")
feed_data = []
for i in range(args.epoch):
try:
for data in train_pyreader():
feed_data.extend(data)
if len(feed_data) == DEV_COUNT:
avg_cost_np, avg_pred_np, pred_label, label = train_exe.run(feed=feed_data, program=compiled_prog, \
fetch_list=fetch_list)
feed_data = []
steps += 1
if steps % int(args.skip_steps) == 0:
time_end = time.time()
used_time = time_end - time_begin
get_score(pred_label, label, eval_phase = "Train")
logger.info('loss is {}'.format(avg_cost_np))
logger.info("epoch: %d, step: %d, speed: %f steps/s" % (i, steps, args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(train_exe, save_path, train_prog)
logger.info("[save]step %d : save at %s" % (steps, save_path))
if steps % args.validation_steps == 0:
if args.do_eval:
evaluate(args, test_exe, build_res["eval_prog"], build_res, place, "eval")
if args.do_test:
evaluate(args, test_exe, build_res["test_prog"], build_res, place, "test")
avg_cost_np, avg_pred_np, pred_label, label = train_exe.run(feed=data, program=compiled_prog, \
fetch_list=fetch_list)
steps += 1
if steps % int(args.skip_steps) == 0:
time_end = time.time()
used_time = time_end - time_begin
get_score(pred_label, label, eval_phase = "Train")
logger.info('loss is {}'.format(avg_cost_np))
logger.info("epoch: %d, step: %d, speed: %f steps/s" % (i, steps, args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(train_exe, save_path, train_prog)
logger.info("[save]step %d : save at %s" % (steps, save_path))
if steps % args.validation_steps == 0:
if args.do_eval:
evaluate(args, test_exe, build_res["eval_prog"], build_res, place, "eval")
if args.do_test:
evaluate(args, test_exe, build_res["test_prog"], build_res, place, "test")
except Exception as e:
logger.exception(str(e))
logger.error("Train error : %s" % str(e))
......@@ -400,14 +396,14 @@ def main(args):
random.seed(args.random_seed)
model_config = ConfigReader.read_conf(args.config_path)
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
place = fluid.cuda_places()
DEV_COUNT = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
place = fluid.cpu_places()
os.environ['CPU_NUM'] = str(args.cpu_num)
DEV_COUNT = args.cpu_num
logger.info("Dev Num is %s" % str(DEV_COUNT))
exe = fluid.Executor(place)
exe = fluid.Executor(place[0])
if args.do_train and args.build_dict:
DataProcesser.build_dict(args.data_dir + "train.txt", args.data_dir)
# read dict
......@@ -436,8 +432,12 @@ def main(args):
if args.do_train:
build_strategy = fluid.compiler.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
compiled_prog = fluid.compiler.CompiledProgram(build_res["train_prog"]).with_data_parallel( \
loss_name=build_res["cost"].name, build_strategy=build_strategy)
loss_name=build_res["cost"].name, build_strategy=build_strategy,
exec_strategy=exec_strategy)
build_res["compiled_prog"] = compiled_prog
train(args, exe, compiled_prog, build_res, place)
if args.do_eval:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册