提交 1fbae957 编写于 作者: B bingyanghuang

fix chinese_ner inference appearing linear_chain_crf_grad op

上级 beb4dfbe
......@@ -61,6 +61,21 @@ def load_reverse_dict(dict_path):
return dict((idx, line.strip().split("\t")[0])
for idx, line in enumerate(open(dict_path, "r").readlines()))
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def infer(args):
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
......@@ -93,9 +108,13 @@ def infer(args):
profiler.reset_profiler()
iters = 0
for data in test_data():
word = to_lodtensor(map(lambda x: x[0], data), place)
mention = to_lodtensor(map(lambda x: x[1], data), place)
start = time.time()
crf_decode = exe.run(inference_program,
feed=feeder.feed(data),
feed={"word": word,
"mention": mention},
fetch_list=fetch_targets,
return_numpy=False)
batch_time = time.time() - start
......
......@@ -266,12 +266,13 @@ def main(args):
with fluid.program_guard(main, startup):
avg_cost, feature_out, word, mention, target = ner_net(
args.word_dict_len, args.label_dict_len)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
......@@ -348,8 +349,8 @@ def main(args):
+ str(f1))
save_dirname = os.path.join(args.model_save_dir,
"params_pass_%d" % pass_id)
fluid.io.save_inference_model(
save_dirname, ['word', 'mention', 'target'], [crf_decode], exe)
fluid.io.save_inference_model(save_dirname, ['word', 'mention'],
[crf_decode], exe)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册