提交 8db0319c 编写于 作者: SYSU_BOND's avatar SYSU_BOND 提交者: bbking

Fix infer bug on Release/1.6 (#3693)

* update downloads.py

* Fix a bug in ERNIE-based inference
上级 03f81264
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Define the function to create lexical analysis model and model's data reader Define the function to create lexical analysis model and model's data reader
""" """
...@@ -37,22 +36,29 @@ def create_model(args, vocab_size, num_labels, mode='train'): ...@@ -37,22 +36,29 @@ def create_model(args, vocab_size, num_labels, mode='train'):
# model's input data # model's input data
words = fluid.data(name='words', shape=[-1, 1], dtype='int64', lod_level=1) words = fluid.data(name='words', shape=[-1, 1], dtype='int64', lod_level=1)
targets = fluid.data(name='targets', shape=[-1, 1], dtype='int64', lod_level=1) targets = fluid.data(
name='targets', shape=[-1, 1], dtype='int64', lod_level=1)
# for inference process # for inference process
if mode == 'infer': if mode == 'infer':
crf_decode = nets.lex_net(words, args, vocab_size, num_labels, for_infer=True, target=None) crf_decode = nets.lex_net(
return {"feed_list": [words], "words": words, "crf_decode": crf_decode, } words, args, vocab_size, num_labels, for_infer=True, target=None)
return {
"feed_list": [words],
"words": words,
"crf_decode": crf_decode,
}
# for test or train process # for test or train process
avg_cost, crf_decode = nets.lex_net(words, args, vocab_size, num_labels, for_infer=False, target=targets) avg_cost, crf_decode = nets.lex_net(
words, args, vocab_size, num_labels, for_infer=False, target=targets)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks, (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval( num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode, input=crf_decode,
label=targets, label=targets,
chunk_scheme="IOB", chunk_scheme="IOB",
num_chunk_types=int(math.ceil((num_labels - 1) / 2.0))) num_chunk_types=int(math.ceil((num_labels - 1) / 2.0)))
chunk_evaluator = fluid.metrics.ChunkEvaluator() chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset() chunk_evaluator.reset()
...@@ -73,7 +79,14 @@ def create_model(args, vocab_size, num_labels, mode='train'): ...@@ -73,7 +79,14 @@ def create_model(args, vocab_size, num_labels, mode='train'):
return ret return ret
def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None, return_reader=False, mode='train'): def create_pyreader(args,
file_name,
feed_list,
place,
model='lac',
reader=None,
return_reader=False,
mode='train'):
# init reader # init reader
if model == 'lac': if model == 'lac':
...@@ -81,8 +94,7 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None, ...@@ -81,8 +94,7 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None,
feed_list=feed_list, feed_list=feed_list,
capacity=50, capacity=50,
use_double_buffer=True, use_double_buffer=True,
iterable=True iterable=True)
)
if reader == None: if reader == None:
reader = Dataset(args) reader = Dataset(args)
...@@ -93,20 +105,16 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None, ...@@ -93,20 +105,16 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None,
fluid.io.batch( fluid.io.batch(
fluid.io.shuffle( fluid.io.shuffle(
reader.file_reader(file_name), reader.file_reader(file_name),
buf_size=args.traindata_shuffle_buffer buf_size=args.traindata_shuffle_buffer),
), batch_size=args.batch_size),
batch_size=args.batch_size places=place)
),
places=place
)
else: else:
pyreader.decorate_sample_list_generator( pyreader.decorate_sample_list_generator(
fluid.io.batch( fluid.io.batch(
reader.file_reader(file_name, mode=mode), reader.file_reader(
batch_size=args.batch_size file_name, mode=mode),
), batch_size=args.batch_size),
places=place places=place)
)
elif model == 'ernie': elif model == 'ernie':
# create ernie pyreader # create ernie pyreader
...@@ -114,8 +122,7 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None, ...@@ -114,8 +122,7 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None,
feed_list=feed_list, feed_list=feed_list,
capacity=50, capacity=50,
use_double_buffer=True, use_double_buffer=True,
iterable=True iterable=True)
)
if reader == None: if reader == None:
reader = SequenceLabelReader( reader = SequenceLabelReader(
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
...@@ -127,17 +134,21 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None, ...@@ -127,17 +134,21 @@ def create_pyreader(args, file_name, feed_list, place, model='lac', reader=None,
if mode == 'train': if mode == 'train':
pyreader.set_batch_generator( pyreader.set_batch_generator(
reader.data_generator( reader.data_generator(
file_name, args.batch_size, args.epoch, shuffle=True, phase="train" file_name,
), args.batch_size,
places=place args.epoch,
) shuffle=True,
phase="train"),
places=place)
else: else:
pyreader.set_batch_generator( pyreader.set_batch_generator(
reader.data_generator( reader.data_generator(
file_name, args.batch_size, epoch=1, shuffle=False, phase=mode file_name,
), args.batch_size,
places=place epoch=1,
) shuffle=False,
phase=mode),
places=place)
if return_reader: if return_reader:
return pyreader, reader return pyreader, reader
else: else:
...@@ -150,14 +161,20 @@ def create_ernie_model(args, ernie_config): ...@@ -150,14 +161,20 @@ def create_ernie_model(args, ernie_config):
""" """
# ERNIE's input data # ERNIE's input data
src_ids = fluid.data(name='src_ids', shape=[-1, args.max_seq_len, 1], dtype='int64') src_ids = fluid.data(
sent_ids = fluid.data(name='sent_ids', shape=[-1, args.max_seq_len, 1], dtype='int64') name='src_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
pos_ids = fluid.data(name='pos_ids', shape=[-1, args.max_seq_len, 1], dtype='int64') sent_ids = fluid.data(
input_mask = fluid.data(name='input_mask', shape=[-1, args.max_seq_len, 1], dtype='float32') name='sent_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
pos_ids = fluid.data(
name='pos_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
input_mask = fluid.data(
name='input_mask', shape=[-1, args.max_seq_len, 1], dtype='float32')
padded_labels = fluid.data(name='padded_labels', shape=[-1, args.max_seq_len, 1], dtype='int64') padded_labels = fluid.data(
name='padded_labels', shape=[-1, args.max_seq_len, 1], dtype='int64')
seq_lens = fluid.data(name='seq_lens', shape=[-1], dtype='int64', lod_level=0) seq_lens = fluid.data(
name='seq_lens', shape=[-1], dtype='int64', lod_level=0)
squeeze_labels = fluid.layers.squeeze(padded_labels, axes=[-1]) squeeze_labels = fluid.layers.squeeze(padded_labels, axes=[-1])
...@@ -187,28 +204,31 @@ def create_ernie_model(args, ernie_config): ...@@ -187,28 +204,31 @@ def create_ernie_model(args, ernie_config):
input=emission, input=emission,
label=padded_labels, label=padded_labels,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='crfw', name='crfw', learning_rate=args.crf_learning_rate),
learning_rate=args.crf_learning_rate),
length=seq_lens) length=seq_lens)
avg_cost = fluid.layers.mean(x=crf_cost) avg_cost = fluid.layers.mean(x=crf_cost)
crf_decode = fluid.layers.crf_decoding( crf_decode = fluid.layers.crf_decoding(
input=emission, param_attr=fluid.ParamAttr(name='crfw'), length=seq_lens) input=emission,
param_attr=fluid.ParamAttr(name='crfw'),
length=seq_lens)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks, (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval( num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode, input=crf_decode,
label=squeeze_labels, label=squeeze_labels,
chunk_scheme="IOB", chunk_scheme="IOB",
num_chunk_types=int(math.ceil((args.num_labels - 1) / 2.0)), num_chunk_types=int(math.ceil((args.num_labels - 1) / 2.0)),
seq_length=seq_lens) seq_length=seq_lens)
chunk_evaluator = fluid.metrics.ChunkEvaluator() chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset() chunk_evaluator.reset()
ret = { ret = {
"feed_list": [src_ids, sent_ids, pos_ids, input_mask, padded_labels, seq_lens], "feed_list":
[src_ids, sent_ids, pos_ids, input_mask, padded_labels, seq_lens],
"words": src_ids, "words": src_ids,
"labels": padded_labels, "labels": padded_labels,
"seq_lens": seq_lens,
"avg_cost": avg_cost, "avg_cost": avg_cost,
"crf_decode": crf_decode, "crf_decode": crf_decode,
"precision": precision, "precision": precision,
......
...@@ -39,6 +39,7 @@ from models.representation.ernie import ErnieConfig ...@@ -39,6 +39,7 @@ from models.representation.ernie import ErnieConfig
from models.model_check import check_cuda from models.model_check import check_cuda
from models.model_check import check_version from models.model_check import check_version
def evaluate(exe, test_program, test_pyreader, test_ret): def evaluate(exe, test_program, test_pyreader, test_ret):
""" """
Evaluation Function Evaluation Function
...@@ -55,8 +56,7 @@ def evaluate(exe, test_program, test_pyreader, test_ret): ...@@ -55,8 +56,7 @@ def evaluate(exe, test_program, test_pyreader, test_ret):
test_ret["num_label_chunks"], test_ret["num_label_chunks"],
test_ret["num_correct_chunks"], test_ret["num_correct_chunks"],
], ],
feed=data[0] feed=data[0])
)
total_loss.append(loss) total_loss.append(loss)
test_ret["chunk_evaluator"].update(nums_infer, nums_label, nums_correct) test_ret["chunk_evaluator"].update(nums_infer, nums_label, nums_correct)
...@@ -64,9 +64,11 @@ def evaluate(exe, test_program, test_pyreader, test_ret): ...@@ -64,9 +64,11 @@ def evaluate(exe, test_program, test_pyreader, test_ret):
precision, recall, f1 = test_ret["chunk_evaluator"].eval() precision, recall, f1 = test_ret["chunk_evaluator"].eval()
end_time = time.time() end_time = time.time()
print("\t[test] loss: %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s" print(
"\t[test] loss: %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time: %.3f s"
% (np.mean(total_loss), precision, recall, f1, end_time - start_time)) % (np.mean(total_loss), precision, recall, f1, end_time - start_time))
def do_train(args): def do_train(args):
""" """
Main Function Main Function
...@@ -80,14 +82,15 @@ def do_train(args): ...@@ -80,14 +82,15 @@ def do_train(args):
else: else:
dev_count = min(multiprocessing.cpu_count(), args.cpu_num) dev_count = min(multiprocessing.cpu_count(), args.cpu_num)
if (dev_count < args.cpu_num): if (dev_count < args.cpu_num):
print("WARNING: The total CPU NUM in this machine is %d, which is less than cpu_num parameter you set. " print(
"Change the cpu_num from %d to %d"%(dev_count, args.cpu_num, dev_count)) "WARNING: The total CPU NUM in this machine is %d, which is less than cpu_num parameter you set. "
"Change the cpu_num from %d to %d" %
(dev_count, args.cpu_num, dev_count))
os.environ['CPU_NUM'] = str(dev_count) os.environ['CPU_NUM'] = str(dev_count)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
startup_prog = fluid.Program() startup_prog = fluid.Program()
if args.random_seed is not None: if args.random_seed is not None:
startup_prog.random_seed = args.random_seed startup_prog.random_seed = args.random_seed
...@@ -99,49 +102,56 @@ def do_train(args): ...@@ -99,49 +102,56 @@ def do_train(args):
train_ret = creator.create_ernie_model(args, ernie_config) train_ret = creator.create_ernie_model(args, ernie_config)
# ernie pyreader # ernie pyreader
train_pyreader = creator.create_pyreader(args, file_name=args.train_data, train_pyreader = creator.create_pyreader(
feed_list=train_ret['feed_list'], args,
model="ernie", file_name=args.train_data,
place=place) feed_list=train_ret['feed_list'],
model="ernie",
place=place)
test_program = train_program.clone(for_test=True) test_program = train_program.clone(for_test=True)
test_pyreader = creator.create_pyreader(args, file_name=args.test_data, test_pyreader = creator.create_pyreader(
feed_list=train_ret['feed_list'], args,
model="ernie", file_name=args.test_data,
place=place) feed_list=train_ret['feed_list'],
model="ernie",
optimizer = fluid.optimizer.Adam(learning_rate=args.base_learning_rate) place=place)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
optimizer = fluid.optimizer.Adam(
learning_rate=args.base_learning_rate)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
optimizer.minimize(train_ret["avg_cost"]) optimizer.minimize(train_ret["avg_cost"])
lower_mem, upper_mem, unit = fluid.contrib.memory_usage( lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size) program=train_program, batch_size=args.batch_size)
print("Theoretical memory usage in training: %.3f - %.3f %s" % print("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
print("Device count: %d" % dev_count) print("Device count: %d" % dev_count)
exe.run(startup_prog) exe.run(startup_prog)
# load checkpoints # load checkpoints
if args.init_checkpoint and args.init_pretraining_params: if args.init_checkpoint and args.init_pretraining_params:
print("WARNING: args 'init_checkpoint' and 'init_pretraining_params' " print("WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
"both are set! Only arg 'init_checkpoint' is made valid.") "both are set! Only arg 'init_checkpoint' is made valid.")
if args.init_checkpoint: if args.init_checkpoint:
utils.init_checkpoint(exe, args.init_checkpoint, startup_prog) utils.init_checkpoint(exe, args.init_checkpoint, startup_prog)
elif args.init_pretraining_params: elif args.init_pretraining_params:
utils.init_pretraining_params(exe, args.init_pretraining_params, startup_prog) utils.init_pretraining_params(exe, args.init_pretraining_params,
startup_prog)
if dev_count>1 and not args.use_cuda: if dev_count > 1 and not args.use_cuda:
device = "GPU" if args.use_cuda else "CPU" device = "GPU" if args.use_cuda else "CPU"
print("%d %s are used to train model"%(dev_count, device)) print("%d %s are used to train model" % (dev_count, device))
# multi cpu/gpu config # multi cpu/gpu config
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
compiled_prog = fluid.compiler.CompiledProgram(train_program).with_data_parallel( compiled_prog = fluid.compiler.CompiledProgram(
loss_name=train_ret['avg_cost'].name, train_program).with_data_parallel(
build_strategy=build_strategy, loss_name=train_ret['avg_cost'].name,
exec_strategy=exec_strategy) build_strategy=build_strategy,
exec_strategy=exec_strategy)
else: else:
compiled_prog = fluid.compiler.CompiledProgram(train_program) compiled_prog = fluid.compiler.CompiledProgram(train_program)
...@@ -162,16 +172,23 @@ def do_train(args): ...@@ -162,16 +172,23 @@ def do_train(args):
start_time = time.time() start_time = time.time()
outputs = exe.run(program=compiled_prog, feed=data[0], fetch_list=fetch_list) outputs = exe.run(program=compiled_prog,
feed=data[0],
fetch_list=fetch_list)
end_time = time.time() end_time = time.time()
if steps % args.print_steps == 0: if steps % args.print_steps == 0:
loss, precision, recall, f1_score = [np.mean(x) for x in outputs] loss, precision, recall, f1_score = [
print("[train] batch_id = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f, " np.mean(x) for x in outputs
"pyreader queue_size: %d " % (steps, loss, precision, recall, f1_score, ]
end_time - start_time, train_pyreader.queue.size())) print(
"[train] batch_id = %d, loss = %.5f, P: %.5f, R: %.5f, F1: %.5f, elapsed time %.5f, "
"pyreader queue_size: %d " %
(steps, loss, precision, recall, f1_score,
end_time - start_time, train_pyreader.queue.size()))
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.model_save_dir, "step_" + str(steps)) save_path = os.path.join(args.model_save_dir,
"step_" + str(steps))
print("\tsaving model as %s" % (save_path)) print("\tsaving model as %s" % (save_path))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
...@@ -182,7 +199,6 @@ def do_train(args): ...@@ -182,7 +199,6 @@ def do_train(args):
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
def do_eval(args): def do_eval(args):
# init executor # init executor
if args.use_cuda: if args.use_cuda:
...@@ -198,11 +214,13 @@ def do_eval(args): ...@@ -198,11 +214,13 @@ def do_eval(args):
test_ret = creator.create_ernie_model(args, ernie_config) test_ret = creator.create_ernie_model(args, ernie_config)
test_program = test_program.clone(for_test=True) test_program = test_program.clone(for_test=True)
pyreader = creator.create_pyreader(args, file_name=args.test_data, pyreader = creator.create_pyreader(
feed_list=test_ret['feed_list'], args,
model="ernie", file_name=args.test_data,
place=place, feed_list=test_ret['feed_list'],
mode='test',) model="ernie",
place=place,
mode='test', )
print('program startup') print('program startup')
...@@ -212,11 +230,13 @@ def do_eval(args): ...@@ -212,11 +230,13 @@ def do_eval(args):
print('program loading') print('program loading')
# load model # load model
if not args.init_checkpoint: if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if only doing test or infer!") raise ValueError(
"args 'init_checkpoint' should be set if only doing test or infer!")
utils.init_checkpoint(exe, args.init_checkpoint, test_program) utils.init_checkpoint(exe, args.init_checkpoint, test_program)
evaluate(exe, test_program, pyreader, test_ret) evaluate(exe, test_program, pyreader, test_ret)
def do_infer(args): def do_infer(args):
# init executor # init executor
if args.use_cuda: if args.use_cuda:
...@@ -233,41 +253,52 @@ def do_infer(args): ...@@ -233,41 +253,52 @@ def do_infer(args):
infer_ret = creator.create_ernie_model(args, ernie_config) infer_ret = creator.create_ernie_model(args, ernie_config)
infer_program = infer_program.clone(for_test=True) infer_program = infer_program.clone(for_test=True)
print(args.test_data) print(args.test_data)
pyreader, reader = creator.create_pyreader(args, file_name=args.test_data, pyreader, reader = creator.create_pyreader(
feed_list=infer_ret['feed_list'], args,
model="ernie", file_name=args.test_data,
place=place, feed_list=infer_ret['feed_list'],
return_reader=True, model="ernie",
mode='test') place=place,
return_reader=True,
mode='test')
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
# load model # load model
if not args.init_checkpoint: if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if only doing test or infer!") raise ValueError(
"args 'init_checkpoint' should be set if only doing test or infer!")
utils.init_checkpoint(exe, args.init_checkpoint, infer_program) utils.init_checkpoint(exe, args.init_checkpoint, infer_program)
# create dict # create dict
id2word_dict = dict([(str(word_id), word) for word, word_id in reader.vocab.items()]) id2word_dict = dict(
id2label_dict = dict([(str(label_id), label) for label, label_id in reader.label_map.items()]) [(str(word_id), word) for word, word_id in reader.vocab.items()])
id2label_dict = dict([(str(label_id), label)
for label, label_id in reader.label_map.items()])
Dataset = namedtuple("Dataset", ["id2word_dict", "id2label_dict"]) Dataset = namedtuple("Dataset", ["id2word_dict", "id2label_dict"])
dataset = Dataset(id2word_dict, id2label_dict) dataset = Dataset(id2word_dict, id2label_dict)
# make prediction # make prediction
for data in pyreader(): for data in pyreader():
(words, crf_decode) = exe.run(infer_program, (words, crf_decode, seq_lens) = exe.run(infer_program,
fetch_list=[infer_ret["words"], infer_ret["crf_decode"]], fetch_list=[
feed=data[0], infer_ret["words"],
return_numpy=False) infer_ret["crf_decode"],
infer_ret["seq_lens"]
],
feed=data[0],
return_numpy=True)
# User should notice that words had been clipped if long than args.max_seq_len # User should notice that words had been clipped if long than args.max_seq_len
results = utils.parse_result(words, crf_decode, dataset) results = utils.parse_padding_result(words, crf_decode, seq_lens,
dataset)
for sent, tags in results: for sent, tags in results:
result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)] result_list = [
'(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)
]
print(''.join(result_list)) print(''.join(result_list))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
utils.load_yaml(parser, './conf/ernie_args.yaml') utils.load_yaml(parser, './conf/ernie_args.yaml')
...@@ -284,4 +315,3 @@ if __name__ == "__main__": ...@@ -284,4 +315,3 @@ if __name__ == "__main__":
do_infer(args) do_infer(args)
else: else:
print("Usage: %s --mode train|eval|infer " % sys.argv[0]) print("Usage: %s --mode train|eval|infer " % sys.argv[0])
...@@ -148,6 +148,50 @@ def parse_result(words, crf_decode, dataset): ...@@ -148,6 +148,50 @@ def parse_result(words, crf_decode, dataset):
return batch_out return batch_out
def parse_padding_result(words, crf_decode, seq_lens, dataset):
    """Convert padded ERNIE inference output into (words, tags) pairs.

    For each sentence in the batch: maps token ids back to words and label
    ids back to tags via `dataset`, strips the leading [CLS] and trailing
    [SEP] positions (hence the ``[1:seq_len - 1]`` slice), then merges
    sub-tokens into full words according to the IOB tag scheme.

    Args:
        words: int array of token ids, shape (batch, max_seq_len, 1).
        crf_decode: predicted label ids, indexable per sentence as
            ``crf_decode[i][1:seq_len - 1]``.
        seq_lens: per-sentence real lengths, including [CLS]/[SEP].
        dataset: object exposing ``id2word_dict`` and ``id2label_dict``
            mappings keyed by *stringified* ids.

    Returns:
        A list with one ``[sent_out, tags_out]`` pair per sentence, where
        ``sent_out`` holds the merged words and ``tags_out`` the chunk
        types (the tag text before the '-', e.g. 'PER' from 'PER-B').
    """
    # Drop only the trailing feature dim. A bare np.squeeze(words) would
    # also collapse the batch dim when batch_size == 1 and break the
    # per-sentence indexing below.
    words = np.squeeze(words, axis=-1)
    batch_size = len(seq_lens)
    batch_out = []
    for sent_index in range(batch_size):
        # [1:seq_len - 1] skips the [CLS]/[SEP] positions added by ERNIE.
        sent = [
            dataset.id2word_dict[str(token_id)]
            for token_id in words[sent_index][1:seq_lens[sent_index] - 1]
        ]
        tags = [
            dataset.id2label_dict[str(label_id)]
            for label_id in crf_decode[sent_index][1:seq_lens[sent_index] - 1]
        ]

        sent_out = []
        tags_out = []
        partial_word = ""
        for ind, tag in enumerate(tags):
            # The first token of the sentence starts the first word.
            if partial_word == "":
                partial_word = sent[ind]
                tags_out.append(tag.split('-')[0])
                continue

            # A "-B" tag, or an "O" right after a non-"O", begins a new
            # word: flush the accumulated word and start a fresh one.
            if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
                sent_out.append(partial_word)
                tags_out.append(tag.split('-')[0])
                partial_word = sent[ind]
                continue

            # Otherwise the token continues the current word.
            partial_word += sent[ind]

        # Flush the last word (no-op when the sentence had no tokens).
        if len(sent_out) < len(tags_out):
            sent_out.append(partial_word)

        batch_out.append([sent_out, tags_out])
    return batch_out
def init_checkpoint(exe, init_checkpoint_path, main_program): def init_checkpoint(exe, init_checkpoint_path, main_program):
""" """
Init CheckPoint Init CheckPoint
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册