Commit 94fc64bc authored by aprilvkuo, committed by pkpk

fixed test/eval func usage (#4176)

* fix train reader usage error

* fix cpu num set error

* fixed eval/test func

* change pred/eval to 1 thread or 1 gpu card
Co-authored-by: Chen Weihang <sunny_cwh@163.com>
Parent 300be16c
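In short, the commit moves the compiled programs and the single-device test place into the build_res dict instead of threading them through as positional arguments. Condensed from the hunks below (every name appears in the diff):

    # old call shapes
    #   train(args, train_exe, compiled_prog, build_res, place)
    #   evaluate(args, test_exe, test_prog, build_res, place, eval_phase, ...)
    # new call shapes -- programs and the eval place travel inside build_res
    build_res["compiled_prog"] = compiled_prog
    build_res["eval_compiled_prog"] = eval_compiled_prog
    build_res["test_compiled_prog"] = test_compiled_prog
    build_res["test_place"] = test_place
    train(args, exe, build_res, place)
    evaluate(args, exe, build_res, "eval", save_result=True, id2intent=id2intent)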
@@ -122,7 +122,7 @@ def get_score(pred_result, label, eval_phase):
recall, f1, acc))
def train(args, train_exe, compiled_prog, build_res, place):
def train(args, train_exe, build_res, place):
"""[train the net]
Arguments:
@@ -133,6 +133,7 @@ def train(args, train_exe, compiled_prog, build_res, place):
place {[type]} -- [description]
"""
global DEV_COUNT
compiled_prog = build_res["compiled_prog"]
cost = build_res["cost"]
prediction = build_res["prediction"]
pred_label = build_res["pred_label"]
@@ -144,33 +145,29 @@ def train(args, train_exe, compiled_prog, build_res, place):
time_begin = time.time()
test_exe = train_exe
logger.info("Begin training")
feed_data = []
for i in range(args.epoch):
try:
for data in train_pyreader():
feed_data.extend(data)
if len(feed_data) == DEV_COUNT:
avg_cost_np, avg_pred_np, pred_label, label = train_exe.run(feed=feed_data, program=compiled_prog, \
fetch_list=fetch_list)
feed_data = []
steps += 1
if steps % int(args.skip_steps) == 0:
time_end = time.time()
used_time = time_end - time_begin
get_score(pred_label, label, eval_phase = "Train")
logger.info('loss is {}'.format(avg_cost_np))
logger.info("epoch: %d, step: %d, speed: %f steps/s" % (i, steps, args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(train_exe, save_path, train_prog)
logger.info("[save]step %d : save at %s" % (steps, save_path))
if steps % args.validation_steps == 0:
if args.do_eval:
evaluate(args, test_exe, build_res["eval_prog"], build_res, place, "eval")
if args.do_test:
evaluate(args, test_exe, build_res["test_prog"], build_res, place, "test")
avg_cost_np, avg_pred_np, pred_label, label = train_exe.run(feed=data, program=compiled_prog, \
fetch_list=fetch_list)
steps += 1
if steps % int(args.skip_steps) == 0:
time_end = time.time()
used_time = time_end - time_begin
get_score(pred_label, label, eval_phase = "Train")
logger.info('loss is {}'.format(avg_cost_np))
logger.info("epoch: %d, step: %d, speed: %f steps/s" % (i, steps, args.skip_steps / used_time))
time_begin = time.time()
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(train_exe, save_path, train_prog)
logger.info("[save]step %d : save at %s" % (steps, save_path))
if steps % args.validation_steps == 0:
if args.do_eval:
evaluate(args, test_exe, build_res, "eval")
if args.do_test:
evaluate(args, test_exe, build_res, "test")
except Exception as e:
logger.exception(str(e))
logger.error("Train error : %s" % str(e))
@@ -180,7 +177,7 @@ def train(args, train_exe, compiled_prog, build_res, place):
logger.info("[save]step %d : save at %s" % (steps, save_path))
def evaluate(args, test_exe, test_prog, build_res, place, eval_phase, save_result=False, id2intent=None):
def evaluate(args, test_exe, build_res, eval_phase, save_result=False, id2intent=None):
"""[evaluate on dev/test dataset]
Arguments:
@@ -196,6 +193,7 @@ def evaluate(args, test_exe, test_prog, build_res, place, eval_phase, save_resul
save_result {bool} -- [description] (default: {False})
id2intent {[type]} -- [description] (default: {None})
"""
place = build_res["test_place"]
threshold = args.threshold
cost = build_res["cost"]
prediction = build_res["prediction"]
@@ -204,8 +202,10 @@ def evaluate(args, test_exe, test_prog, build_res, place, eval_phase, save_resul
fetch_list = [cost.name, prediction.name, pred_label.name, label.name]
total_cost, total_acc, pred_prob_list, pred_label_list, label_list = [], [], [], [], []
if eval_phase == "eval":
test_prog = build_res["eval_compiled_prog"]
test_pyreader = build_res["eval_pyreader"]
elif eval_phase == "test":
test_prog = build_res["test_compiled_prog"]
test_pyreader = build_res["test_pyreader"]
else:
exit(1)
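With the compiled programs stored in build_res (see the main() hunk further down), the phase switch above reduces to a small lookup. A sketch of the same dispatch that fails loudly instead of calling exit(1) (the helper name is hypothetical, the keys are the ones set in main()):

    def pick_eval_targets(build_res, eval_phase):
        """Return (compiled program, pyreader) for the given phase."""
        if eval_phase == "eval":
            return build_res["eval_compiled_prog"], build_res["eval_pyreader"]
        if eval_phase == "test":
            return build_res["test_compiled_prog"], build_res["test_pyreader"]
        raise ValueError("eval_phase must be 'eval' or 'test', got %r" % eval_phase)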
@@ -332,7 +332,7 @@ def build_data_reader(args, char_dict, intent_dict):
return reader_res
def build_graph(args, model_config, num_labels, dict_dim, place, reader_res):
def build_graph(args, model_config, num_labels, dict_dim, place, test_place, reader_res):
"""[build paddle graph]
Arguments:
@@ -369,15 +369,15 @@ def build_graph(args, model_config, num_labels, dict_dim, place, reader_res):
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
eval_pyreader, cost, prediction, pred_label, label = create_net(args, model_config, num_labels, \
dict_dim, place, model_name="textcnn_net")
eval_pyreader.decorate_sample_list_generator(reader_res['eval_data_generator'], places=place)
dict_dim, test_place, model_name="textcnn_net")
eval_pyreader.decorate_sample_list_generator(reader_res['eval_data_generator'], places=test_place)
res["eval_pyreader"] = eval_pyreader
if args.do_test:
with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard():
test_pyreader, cost, prediction, pred_label, label = create_net(args, model_config, num_labels, \
dict_dim, place, model_name="textcnn_net")
test_pyreader.decorate_sample_list_generator(reader_res['test_data_generator'], places=place)
dict_dim, test_place, model_name="textcnn_net")
test_pyreader.decorate_sample_list_generator(reader_res['test_data_generator'], places=test_place)
res["test_pyreader"] = test_pyreader
res["cost"] = cost
res["prediction"] = prediction
@@ -400,14 +400,16 @@ def main(args):
random.seed(args.random_seed)
model_config = ConfigReader.read_conf(args.config_path)
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
test_place = fluid.cuda_places(0)
place = fluid.cuda_places()
DEV_COUNT = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
test_place = fluid.cpu_places(1)
os.environ['CPU_NUM'] = str(args.cpu_num)
place = fluid.cpu_places()
DEV_COUNT = args.cpu_num
logger.info("Dev Num is %s" % str(DEV_COUNT))
exe = fluid.Executor(place)
exe = fluid.Executor(place[0])
if args.do_train and args.build_dict:
DataProcesser.build_dict(args.data_dir + "train.txt", args.data_dir)
# read dict
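The device setup above now produces lists of places: place holds every training device, while test_place pins eval/test to a single card or core, matching the "1 thread or 1 gpu card" bullet in the commit message. A condensed sketch of what each call returns (assuming a 2-GPU machine for the CUDA branch):

    fluid.cuda_places(0)     # [CUDAPlace(0)]               -> test_place (GPU)
    fluid.cuda_places()      # [CUDAPlace(0), CUDAPlace(1)] -> place, one per visible GPU
    fluid.cpu_places(1)      # [CPUPlace()]                 -> test_place (CPU)
    fluid.cpu_places()       # CPU_NUM copies of CPUPlace() -> place; set CPU_NUM first
    exe = fluid.Executor(place[0])   # the executor itself still binds to one device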
@@ -420,7 +422,9 @@ def main(args):
num_labels = len(intent_dict)
# build model
reader_res = build_data_reader(args, char_dict, intent_dict)
build_res = build_graph(args, model_config, num_labels, dict_dim, place, reader_res)
build_res = build_graph(args, model_config, num_labels, dict_dim, place, test_place, reader_res)
build_res["place"] = place
build_res["test_place"] = test_place
if not (args.do_train or args.do_eval or args.do_test):
raise ValueError("For args `do_train`, `do_eval` and `do_test`, at "
"least one of them must be True.")
@@ -433,24 +437,34 @@ def main(args):
except Exception as e:
logger.exception(str(e))
logger.error("Faild load model from %s [%s]" % (args.init_checkpoint, str(e)))
build_strategy = fluid.compiler.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
# add compiled prog
if args.do_train:
build_strategy = fluid.compiler.BuildStrategy()
compiled_prog = fluid.compiler.CompiledProgram(build_res["train_prog"]).with_data_parallel( \
loss_name=build_res["cost"].name, build_strategy=build_strategy)
loss_name=build_res["cost"].name, \
build_strategy=build_strategy, \
exec_strategy=exec_strategy)
build_res["compiled_prog"] = compiled_prog
train(args, exe, compiled_prog, build_res, place)
if args.do_test:
test_compiled_prog = fluid.compiler.CompiledProgram(build_res["test_prog"])
build_res["test_compiled_prog"] = test_compiled_prog
if args.do_eval:
evaluate(args, exe, build_res["eval_prog"], build_res, place, "eval", \
eval_compiled_prog = fluid.compiler.CompiledProgram(build_res["eval_prog"])
build_res["eval_compiled_prog"] = eval_compiled_prog
if args.do_train:
train(args, exe, build_res, place)
if args.do_eval:
evaluate(args, exe, build_res, "eval", \
save_result=True, id2intent=id2intent)
if args.do_test:
evaluate(args, exe, build_res["test_prog"], build_res, place, "test",\
evaluate(args, exe, build_res, "test",\
save_result=True, id2intent=id2intent)
if __name__ == "__main__":
logger.info("the paddle version is %s" % paddle.__version__)