提交 dcfa4624 编写于 作者: T tianxin04

1. add ernie encoder for extracing sentence/tokens embeddings based on ERNIE

 2. return seq_lens by pad_batch_data function
上级 fbeac9b3
......@@ -166,7 +166,8 @@ def pad_batch_data(insts,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
return_num_token=False,
return_seq_lens=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
......@@ -205,6 +206,10 @@ def pad_batch_data(insts,
num_token += len(inst)
return_list += [num_token]
if return_seq_lens:
seq_lens = np.array([len(inst) for inst in insts])
return_list += [seq_lens.astype("int64").reshape([-1, 1])]
return return_list if len(return_list) > 1 else return_list[0]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""extract embeddings from ERNIE encoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
import multiprocessing
import paddle.fluid as fluid
import reader.task_reader as task_reader
from model.ernie import ErnieConfig, ErnieModel
from utils.args import ArgumentGroup, print_arguments
from utils.init import init_pretraining_params
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for ernie model config.")
model_g.add_arg("init_pretraining_params", str, None,
"Init pre-training params which preforms fine-tuning from. If the "
"arg 'init_checkpoint' has been set, this argument wouldn't be valid.")
model_g.add_arg("output_dir", str, "embeddings", "path to save embeddings extracted by ernie_encoder.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_set", str, None, "Path to data for calculating ernie_embeddings.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
# yapf: enable
def create_model(args, pyreader_name, ernie_config, is_prediction=False):
pyreader = fluid.layers.py_reader(
capacity=50,
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1]],
dtypes=['int64', 'int64', 'int64', 'float', 'int64'],
lod_levels=[0, 0, 0, 0, 0],
name=pyreader_name,
use_double_buffer=True)
(src_ids, sent_ids, pos_ids, input_mask,
seq_lens) = fluid.layers.read_file(pyreader)
ernie = ErnieModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
input_mask=input_mask,
config=ernie_config)
enc_out = ernie.get_sequence_output()
unpad_enc_out = fluid.layers.sequence_unpad(enc_out, length=seq_lens)
cls_feats = ernie.get_pooled_output()
# set persistable = True to avoid memory opimizing
enc_out.persistable = True
unpad_enc_out.persistable = True
cls_feats.persistable = True
graph_vars = {
"cls_embeddings": cls_feats,
"top_layer_embeddings": unpad_enc_out,
}
return pyreader, graph_vars
def main(args):
args = parser.parse_args()
ernie_config = ErnieConfig(args.ernie_config_path)
ernie_config.print_config()
if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
reader = task_reader.ExtractEmbeddingReader(
vocab_path=args.vocab_path,
max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case,
random_seed=args.random_seed)
startup_prog = fluid.Program()
if args.random_seed is not None:
startup_prog.random_seed = args.random_seed
data_generator = reader.data_generator(
input_file=args.data_set,
batch_size=args.batch_size,
epoch=1,
shuffle=False,
phase="train")
total_examples = reader.get_num_examples(args.data_set)
print("Device count: %d" % dev_count)
print("Total num examples: %d" % total_examples)
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard():
pyreader, graph_vars = create_model(
args, pyreader_name='reader', ernie_config=ernie_config)
fluid.memory_optimize(input_program=train_program)
train_program = train_program.clone(for_test=True)
exe.run(startup_prog)
if args.init_pretraining_params:
init_pretraining_params(
exe, args.init_pretraining_params, main_program=startup_prog)
else:
raise ValueError(
"WARNING: args 'init_pretraining_params' must be specified")
exec_strategy = fluid.ExecutionStrategy()
if args.use_fast_executor:
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = dev_count
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
pyreader.decorate_tensor_provider(data_generator)
pyreader.start()
total_cls_emb = []
total_top_layer_emb = []
total_labels = []
while True:
try:
cls_emb, unpad_top_layer_emb = exe.run(
program=train_program,
fetch_list=[
graph_vars["cls_embeddings"].name, graph_vars[
"top_layer_embeddings"].name
],
return_numpy=False)
# batch_size * embedding_size
total_cls_emb.append(np.array(cls_emb))
total_top_layer_emb.append(np.array(unpad_top_layer_emb))
except fluid.core.EOFException:
break
total_cls_emb = np.concatenate(total_cls_emb)
total_top_layer_emb = np.concatenate(total_top_layer_emb)
with open(os.path.join(args.output_dir, "cls_emb.npy"),
"w") as cls_emb_file:
np.save(cls_emb_file, total_cls_emb)
with open(os.path.join(args.output_dir, "top_layer_emb.npy"),
"w") as top_layer_emb_file:
np.save(top_layer_emb_file, total_top_layer_emb)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
main(args)
......@@ -20,9 +20,8 @@ from __future__ import print_function
import time
import numpy as np
import paddle.fluid as fluid
from six.moves import xrange
import paddle.fluid as fluid
from model.ernie import ErnieModel
......
......@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import json
import numpy as np
import six
import paddle.fluid as fluid
from model.transformer_encoder import encoder, pre_process_layer
......
......@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import csv
import json
import numpy as np
......@@ -50,7 +49,6 @@ class BaseReader(object):
self.label_map = json.load(f)
else:
self.label_map = None
pass
def get_train_progress(self):
"""Gets progress for training phase."""
......@@ -183,7 +181,7 @@ class BaseReader(object):
yield self._pad_batch_records(batch_records)
batch_records, max_len = [record], len(record.token_ids)
if len(batch_records) > 0:
if batch_records:
yield self._pad_batch_records(batch_records)
def get_num_examples(self, input_file):
......@@ -268,19 +266,19 @@ class SequenceLabelReader(BaseReader):
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
batch_seq_lens = [len(record.token_ids) for record in batch_records]
# padding
padded_token_ids, input_mask = pad_batch_data(
batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids, pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids, pad_idx=len(self.label_map) - 1)
batch_seq_lens = np.array(batch_seq_lens).astype("int64").reshape(
[-1, 1])
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
......@@ -337,5 +335,30 @@ class SequenceLabelReader(BaseReader):
return record
class ExtractEmbeddingReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
# padding
padded_token_ids, input_mask, seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids, pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, seq_lens
]
return return_list
if __name__ == '__main__':
pass
......@@ -19,18 +19,15 @@ from __future__ import print_function
import os
import time
import argparse
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from finetune.classifier import create_model, evaluate
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments
from utils.args import print_arguments
from utils.init import init_pretraining_params, init_checkpoint
from finetune_args import parser
......@@ -184,12 +181,6 @@ def main(args):
else:
train_exe = None
if args.do_val or args.do_test:
test_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
main_program=test_prog,
share_vars_from=train_exe)
if args.do_train:
train_pyreader.start()
steps = 0
......@@ -238,7 +229,8 @@ def main(args):
batch_size=args.batch_size,
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, graph_vars, "dev")
evaluate(exe, test_prog, test_pyreader, graph_vars,
"dev")
# evaluate test set
if args.do_test:
test_pyreader.decorate_tensor_provider(
......@@ -247,7 +239,8 @@ def main(args):
batch_size=args.batch_size,
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, graph_vars, "test")
evaluate(exe, test_prog, test_pyreader, graph_vars,
"test")
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
......
......@@ -19,10 +19,8 @@ from __future__ import print_function
import os
import time
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
import reader.task_reader as task_reader
......@@ -264,7 +262,8 @@ def main(args):
epoch=1,
shuffle=False))
print("Final validation result:")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, "dev")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels,
"dev")
# final eval on test set
if args.do_test:
......@@ -275,7 +274,8 @@ def main(args):
epoch=1,
shuffle=False))
print("Final test result:")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, "test")
evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels,
"test")
if __name__ == '__main__':
......
......@@ -19,17 +19,15 @@ from __future__ import print_function
import os
import time
import argparse
import numpy as np
import multiprocessing
import paddle
import numpy as np
import paddle.fluid as fluid
from reader.pretraining import ErnieDataReader
from model.ernie import ErnieModel, ErnieConfig
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments
from utils.args import print_arguments
from utils.init import init_checkpoint, init_pretraining_params
from pretrain_args import parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册