未验证 提交 b67ce353 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

speed up test label semantic roles (#10718)

上级 14248a64
...@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
crf_decode = fluid.layers.crf_decoding( crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw')) input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
train_data = paddle.batch( train_data = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192), paddle.dataset.conll05.test(), buf_size=8192),
...@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
def train_loop(main_program): def train_loop(main_program):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
embedding_param = fluid.global_scope().find_var( embedding_param = fluid.global_scope().find_var(
embedding_name).get_tensor() embedding_name).get_tensor()
embedding_param.set( embedding_param.set(
...@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True): ...@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True):
start_time = time.time() start_time = time.time()
batch_id = 0 batch_id = 0
for pass_id in xrange(PASS_NUM): for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
for data in train_data(): for data in train_data():
cost, precision, recall, f1_score = exe.run( cost = exe.run(main_program,
main_program, feed=feeder.feed(data),
feed=feeder.feed(data), fetch_list=[avg_cost])
fetch_list=[avg_cost] + chunk_evaluator.metrics) cost = cost[0]
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
exe)
if batch_id % 10 == 0: if batch_id % 10 == 0:
print("avg_cost:" + str(cost) + " precision:" + str( print("avg_cost:" + str(cost))
precision) + " recall:" + str(recall) + " f1_score:" +
str(f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(
pass_recall) + " pass_f1_score:" + str(
pass_f1_score))
if batch_id != 0: if batch_id != 0:
print("second per batch: " + str((time.time( print("second per batch: " + str((time.time(
) - start_time) / batch_id)) ) - start_time) / batch_id))
# Set the threshold low to speed up the CI test # Set the threshold low to speed up the CI test
if float(pass_precision) > 0.01: if float(cost) < 60.0:
if save_dirname is not None: if save_dirname is not None:
# TODO(liuyiqun): Change the target to crf_decode # TODO(liuyiqun): Change the target to crf_decode
fluid.io.save_inference_model(save_dirname, [ fluid.io.save_inference_model(save_dirname, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册