未验证 提交 b67ce353 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

speed up test label semantic roles (#10718)

上级 14248a64
......@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
......@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
def train_loop(main_program):
exe.run(fluid.default_startup_program())
embedding_param = fluid.global_scope().find_var(
embedding_name).get_tensor()
embedding_param.set(
......@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True):
start_time = time.time()
batch_id = 0
for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
for data in train_data():
cost, precision, recall, f1_score = exe.run(
main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics)
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
exe)
cost = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost])
cost = cost[0]
if batch_id % 10 == 0:
print("avg_cost:" + str(cost) + " precision:" + str(
precision) + " recall:" + str(recall) + " f1_score:" +
str(f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(
pass_recall) + " pass_f1_score:" + str(
pass_f1_score))
print("avg_cost:" + str(cost))
if batch_id != 0:
print("second per batch: " + str((time.time(
) - start_time) / batch_id))
# Set the threshold low to speed up the CI test
if float(pass_precision) > 0.01:
if float(cost) < 60.0:
if save_dirname is not None:
# TODO(liuyiqun): Change the target to crf_decode
fluid.io.save_inference_model(save_dirname, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册