Commit 6ed5f04d authored by SYSU_BOND, committed by bbking

Replace open with io.open to be compatible with Windows (#3707)

* update downloads.py

* fix a bug in ERNIE-based inference

* replace open with io.open to be compatible with Windows
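Why the change matters: under Python 2, the built-in open() decodes text with the OS locale encoding (e.g. cp936 on a Chinese Windows install), and Python 3 does the same when no encoding is passed, so reading the UTF-8 dictionaries and corpora used here can raise UnicodeDecodeError. io.open() takes an explicit encoding argument and behaves identically on Python 2 and 3. A minimal sketch of the pattern this commit applies throughout (load_lines is a hypothetical helper, not part of the patch):

    # -*- coding: utf-8 -*-
    # Sketch of the open() -> io.open() pattern used throughout this commit.
    # io.open() exists on both Python 2 and 3 (on Python 3 it is an alias of
    # the built-in open), and the explicit encoding makes reads independent
    # of the OS locale.
    import io

    def load_lines(path):
        # Hypothetical helper: read a UTF-8 text file line by line.
        with io.open(path, 'r', encoding='utf8') as f:
            return [line.rstrip('\n') for line in f]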
Parent 68d6379c
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 evaluate wordseg for LAC and other open-source wordseg tools
 """
-
@@ -21,6 +20,7 @@ from __future__ import division
 import sys
 import os
+import io
@@ -71,7 +71,7 @@ def load_testdata(datapath="./data/test_data/test_part"):
     """none"""
     sentences = []
     sent_seg_list = []
-    for line in open(datapath):
+    for line in io.open(datapath, 'r', encoding='utf8'):
         sent, label = line.strip().split("\t")
         sentences.append(sent)
@@ -110,7 +110,7 @@ def get_lac_result():
         `sh run.sh | tail -n 100 > result.txt`
     """
     sent_seg_list = []
-    for line in open("./result.txt"):
+    for line in io.open("./result.txt", 'r', encoding='utf8'):
         line = line.strip().split(" ")
         words = [pair.split("/")[0] for pair in line]
         labels = [pair.split("/")[1] for pair in line]
...
@@ -31,20 +31,31 @@ from model_check import check_version
 parser = argparse.ArgumentParser(__doc__)
 # 1. model parameters
 model_g = utils.ArgumentGroup(parser, "model", "model configuration")
-model_g.add_arg("word_emb_dim", int, 128, "The dimension in which a word is embedded.")
-model_g.add_arg("grnn_hidden_dim", int, 128, "The number of hidden nodes in the GRNN layer.")
-model_g.add_arg("bigru_num", int, 2, "The number of bi_gru layers in the network.")
+model_g.add_arg("word_emb_dim", int, 128,
+                "The dimension in which a word is embedded.")
+model_g.add_arg("grnn_hidden_dim", int, 128,
+                "The number of hidden nodes in the GRNN layer.")
+model_g.add_arg("bigru_num", int, 2,
+                "The number of bi_gru layers in the network.")
 model_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
 # 2. data parameters
 data_g = utils.ArgumentGroup(parser, "data", "data paths")
-data_g.add_arg("word_dict_path", str, "./conf/word.dic", "The path of the word dictionary.")
-data_g.add_arg("label_dict_path", str, "./conf/tag.dic", "The path of the label dictionary.")
-data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic", "The path of the word replacement Dictionary.")
-data_g.add_arg("test_data", str, "./data/test.tsv", "The folder where the training data is located.")
+data_g.add_arg("word_dict_path", str, "./conf/word.dic",
+               "The path of the word dictionary.")
+data_g.add_arg("label_dict_path", str, "./conf/tag.dic",
+               "The path of the label dictionary.")
+data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic",
+               "The path of the word replacement Dictionary.")
+data_g.add_arg("test_data", str, "./data/test.tsv",
+               "The folder where the training data is located.")
 data_g.add_arg("init_checkpoint", str, "./model_baseline", "Path to init model")
-data_g.add_arg("batch_size", int, 200, "The number of sequences contained in a mini-batch, "
-               "or the maximum number of tokens (include paddings) contained in a mini-batch.")
+data_g.add_arg(
+    "batch_size", int, 200,
+    "The number of sequences contained in a mini-batch, "
+    "or the maximum number of tokens (include paddings) contained in a mini-batch."
+)
 def do_eval(args):
     dataset = reader.Dataset(args)
@@ -62,23 +73,23 @@ def do_eval(args):
     else:
         place = fluid.CPUPlace()
-    pyreader = creator.create_pyreader(args, file_name=args.test_data,
-                                       feed_list=test_ret['feed_list'],
-                                       place=place,
-                                       model='lac',
-                                       reader=dataset,
-                                       mode='test')
+    pyreader = creator.create_pyreader(
+        args,
+        file_name=args.test_data,
+        feed_list=test_ret['feed_list'],
+        place=place,
+        model='lac',
+        reader=dataset,
+        mode='test')
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
     # load model
     utils.init_checkpoint(exe, args.init_checkpoint, test_program)
-    test_process(exe=exe,
-                 program=test_program,
-                 reader=pyreader,
-                 test_ret=test_ret
-                 )
+    test_process(
+        exe=exe, program=test_program, reader=pyreader, test_ret=test_ret)
 def test_process(exe, program, reader, test_ret):
     """
@@ -93,20 +104,21 @@ def test_process(exe, program, reader, test_ret):
     start_time = time.time()
     for data in reader():
-        nums_infer, nums_label, nums_correct = exe.run(program,
-                                                       fetch_list=[
-                                                           test_ret["num_infer_chunks"],
-                                                           test_ret["num_label_chunks"],
-                                                           test_ret["num_correct_chunks"],
-                                                       ],
-                                                       feed=data,
-                                                       )
+        nums_infer, nums_label, nums_correct = exe.run(
+            program,
+            fetch_list=[
+                test_ret["num_infer_chunks"],
+                test_ret["num_label_chunks"],
+                test_ret["num_correct_chunks"],
+            ],
+            feed=data, )
         test_ret["chunk_evaluator"].update(nums_infer, nums_label, nums_correct)
     precision, recall, f1 = test_ret["chunk_evaluator"].eval()
     end_time = time.time()
-    print("[test] P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s"
-          % (precision, recall, f1, end_time - start_time))
+    print("[test] P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s" %
+          (precision, recall, f1, end_time - start_time))
 if __name__ == '__main__':
     args = parser.parse_args()
...
@@ -14,6 +14,7 @@ sys.path.append('../models/')
 from model_check import check_cuda
 from model_check import check_version
+
 def save_inference_model(args):
     # model definition
@@ -30,20 +31,19 @@ def save_inference_model(args):
         args, dataset.vocab_size, dataset.num_labels, mode='infer')
     infer_program = infer_program.clone(for_test=True)
-
     # load pretrain check point
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
     utils.init_checkpoint(exe, args.init_checkpoint, infer_program)
-    fluid.io.save_inference_model(args.inference_save_dir,
-                                  ['words'],
-                                  infer_ret['crf_decode'],
-                                  exe,
-                                  main_program=infer_program,
-                                  model_filename='model.pdmodel',
-                                  params_filename='params.pdparams',
-                                  )
+    fluid.io.save_inference_model(
+        args.inference_save_dir,
+        ['words'],
+        infer_ret['crf_decode'],
+        exe,
+        main_program=infer_program,
+        model_filename='model.pdmodel',
+        params_filename='params.pdparams', )
 def test_inference_model(model_dir, text_list, dataset):
@@ -68,45 +68,46 @@ def test_inference_model(model_dir, text_list, dataset):
     tensor_words = fluid.create_lod_tensor(lod, base_shape, place)
     # for empty input, output the same empty
-    if(sum(base_shape[0]) == 0 ):
+    if (sum(base_shape[0]) == 0):
         crf_decode = [tensor_words]
     else:
         # load inference model
         inference_scope = fluid.core.Scope()
         with fluid.scope_guard(inference_scope):
             [inferencer, feed_target_names,
-             fetch_targets] = fluid.io.load_inference_model(model_dir, exe,
-                                                            model_filename='model.pdmodel',
-                                                            params_filename='params.pdparams',
-                                                            )
+             fetch_targets] = fluid.io.load_inference_model(
+                 model_dir,
+                 exe,
+                 model_filename='model.pdmodel',
+                 params_filename='params.pdparams', )
             assert feed_target_names[0] == "words"
-            print("Load inference model from %s"%(model_dir))
+            print("Load inference model from %s" % (model_dir))
             # get lac result
-            crf_decode = exe.run(inferencer,
-                                 feed={feed_target_names[0]:tensor_words},
-                                 fetch_list=fetch_targets,
-                                 return_numpy=False,
-                                 use_program_cache=True,
-                                 )
+            crf_decode = exe.run(
+                inferencer,
+                feed={feed_target_names[0]: tensor_words},
+                fetch_list=fetch_targets,
+                return_numpy=False,
+                use_program_cache=True, )
     # parse the crf_decode result
-    result = utils.parse_result(tensor_words,crf_decode[0], dataset)
-    for i,(sent, tags) in enumerate(result):
-        result_list = ['(%s, %s)'%(ch, tag) for ch, tag in zip(sent,tags)]
+    result = utils.parse_result(tensor_words, crf_decode[0], dataset)
+    for i, (sent, tags) in enumerate(result):
+        result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
         print(''.join(result_list))
-if __name__=="__main__":
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(__doc__)
-    utils.load_yaml(parser,'conf/args.yaml')
+    utils.load_yaml(parser, 'conf/args.yaml')
     args = parser.parse_args()
     check_cuda(args.use_cuda)
     check_version()
     print("save inference model")
     save_inference_model(args)
-    print("inference model save in %s"%args.inference_save_dir)
+    print("inference model save in %s" % args.inference_save_dir)
     print("test inference model")
     dataset = reader.Dataset(args)
     test_data = [u'百度是一家高科技公司', u'中山大学是岭南第一学府']
...
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
 import os
 import time
@@ -30,20 +31,31 @@ from model_check import check_version
 parser = argparse.ArgumentParser(__doc__)
 # 1. model parameters
 model_g = utils.ArgumentGroup(parser, "model", "model configuration")
-model_g.add_arg("word_emb_dim", int, 128, "The dimension in which a word is embedded.")
-model_g.add_arg("grnn_hidden_dim", int, 256, "The number of hidden nodes in the GRNN layer.")
-model_g.add_arg("bigru_num", int, 2, "The number of bi_gru layers in the network.")
+model_g.add_arg("word_emb_dim", int, 128,
+                "The dimension in which a word is embedded.")
+model_g.add_arg("grnn_hidden_dim", int, 256,
+                "The number of hidden nodes in the GRNN layer.")
+model_g.add_arg("bigru_num", int, 2,
+                "The number of bi_gru layers in the network.")
 model_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
 # 2. data parameters
 data_g = utils.ArgumentGroup(parser, "data", "data paths")
-data_g.add_arg("word_dict_path", str, "./conf/word.dic", "The path of the word dictionary.")
-data_g.add_arg("label_dict_path", str, "./conf/tag.dic", "The path of the label dictionary.")
-data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic", "The path of the word replacement Dictionary.")
-data_g.add_arg("infer_data", str, "./data/infer.tsv", "The folder where the training data is located.")
+data_g.add_arg("word_dict_path", str, "./conf/word.dic",
+               "The path of the word dictionary.")
+data_g.add_arg("label_dict_path", str, "./conf/tag.dic",
+               "The path of the label dictionary.")
+data_g.add_arg("word_rep_dict_path", str, "./conf/q2b.dic",
+               "The path of the word replacement Dictionary.")
+data_g.add_arg("infer_data", str, "./data/infer.tsv",
+               "The folder where the training data is located.")
 data_g.add_arg("init_checkpoint", str, "./model_baseline", "Path to init model")
-data_g.add_arg("batch_size", int, 200, "The number of sequences contained in a mini-batch, "
-               "or the maximum number of tokens (include paddings) contained in a mini-batch.")
+data_g.add_arg(
+    "batch_size", int, 200,
+    "The number of sequences contained in a mini-batch, "
+    "or the maximum number of tokens (include paddings) contained in a mini-batch."
+)
 def do_infer(args):
     dataset = reader.Dataset(args)
@@ -61,14 +73,14 @@ def do_infer(args):
     else:
         place = fluid.CPUPlace()
-    pyreader = creator.create_pyreader(args, file_name=args.infer_data,
-                                       feed_list=infer_ret['feed_list'],
-                                       place=place,
-                                       model='lac',
-                                       reader=dataset,
-                                       mode='infer')
+    pyreader = creator.create_pyreader(
+        args,
+        file_name=args.infer_data,
+        feed_list=infer_ret['feed_list'],
+        place=place,
+        model='lac',
+        reader=dataset,
+        mode='infer')
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
@@ -81,8 +93,7 @@ def do_infer(args):
         program=infer_program,
         reader=pyreader,
         fetch_vars=[infer_ret['words'], infer_ret['crf_decode']],
-        dataset=dataset
-        )
+        dataset=dataset)
     for sent, tags in result:
         result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
         print(''.join(result_list))
@@ -96,8 +107,9 @@ def infer_process(exe, program, reader, fetch_vars, dataset):
     :param reader: data reader
     :return: the list of prediction result
     """
+
     def input_check(data):
-        if data[0]['words'].lod()[0][-1]==0:
+        if data[0]['words'].lod()[0][-1] == 0:
             return data[0]['words']
         return None
@@ -108,17 +120,17 @@ def infer_process(exe, program, reader, fetch_vars, dataset):
             results += utils.parse_result(crf_decode, crf_decode, dataset)
             continue
-        words, crf_decode = exe.run(program,
-                                    fetch_list=fetch_vars,
-                                    feed=data,
-                                    return_numpy=False,
-                                    use_program_cache=True,
-                                    )
+        words, crf_decode = exe.run(
+            program,
+            fetch_list=fetch_vars,
+            feed=data,
+            return_numpy=False,
+            use_program_cache=True, )
         results += utils.parse_result(words, crf_decode, dataset)
     return results
-if __name__=="__main__":
+if __name__ == "__main__":
     args = parser.parse_args()
     check_cuda(args.use_cuda)
     check_version()
...
@@ -14,6 +14,7 @@
 """
 The file_reader converts raw corpus to input.
 """
+
 import os
 import argparse
 import __future__
...
@@ -20,6 +20,7 @@ import sys
 import numpy as np
 import paddle.fluid as fluid
 import yaml
+import io
 def str2bool(v):
@@ -50,7 +51,7 @@ class ArgumentGroup(object):
 def load_yaml(parser, file_name, **kwargs):
-    with open(file_name) as f:
+    with io.open(file_name, 'r', encoding='utf8') as f:
         args = yaml.load(f)
     for title in args:
         group = parser.add_argument_group(title=title, description='')
...