diff --git a/ERNIE/batching.py b/ERNIE/batching.py index 065b11b6e6874abc10b6c6f7728f5c8e87c41c47..c3130a3bbe14ae31fbbf08ff8fa005b61ff305ba 100644 --- a/ERNIE/batching.py +++ b/ERNIE/batching.py @@ -166,7 +166,8 @@ def pad_batch_data(insts, return_pos=False, return_input_mask=False, return_max_len=False, - return_num_token=False): + return_num_token=False, + return_seq_lens=False): """ Pad the instances to the max sequence length in batch, and generate the corresponding position data and attention bias. @@ -205,6 +206,10 @@ def pad_batch_data(insts, num_token += len(inst) return_list += [num_token] + if return_seq_lens: + seq_lens = np.array([len(inst) for inst in insts]) + return_list += [seq_lens.astype("int64").reshape([-1, 1])] + return return_list if len(return_list) > 1 else return_list[0] diff --git a/ERNIE/ernir_encoder.py b/ERNIE/ernir_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae34cef6dde1990a4618096a38280381d3f7b53 --- /dev/null +++ b/ERNIE/ernir_encoder.py @@ -0,0 +1,192 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""extract embeddings from ERNIE encoder.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import argparse +import numpy as np +import multiprocessing + +import paddle.fluid as fluid + +import reader.task_reader as task_reader +from model.ernie import ErnieConfig, ErnieModel +from utils.args import ArgumentGroup, print_arguments +from utils.init import init_pretraining_params + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +model_g = ArgumentGroup(parser, "model", "model configuration and paths.") +model_g.add_arg("ernie_config_path", str, None, "Path to the json file for ernie model config.") +model_g.add_arg("init_pretraining_params", str, None, + "Init pre-training params which preforms fine-tuning from. If the " + "arg 'init_checkpoint' has been set, this argument wouldn't be valid.") +model_g.add_arg("output_dir", str, "embeddings", "path to save embeddings extracted by ernie_encoder.") + +data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options") +data_g.add_arg("data_set", str, None, "Path to data for calculating ernie_embeddings.") +data_g.add_arg("vocab_path", str, None, "Vocabulary path.") +data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.") +data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.") +data_g.add_arg("do_lower_case", bool, True, + "Whether to lower case the input text. 
Should be True for uncased models and False for cased models.") +data_g.add_arg("random_seed", int, 0, "Random seed.") + +run_type_g = ArgumentGroup(parser, "run_type", "running type options.") +run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.") +run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).") +run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.") +# yapf: enable + + +def create_model(args, pyreader_name, ernie_config, is_prediction=False): + pyreader = fluid.layers.py_reader( + capacity=50, + shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], + [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1]], + dtypes=['int64', 'int64', 'int64', 'float', 'int64'], + lod_levels=[0, 0, 0, 0, 0], + name=pyreader_name, + use_double_buffer=True) + + (src_ids, sent_ids, pos_ids, input_mask, + seq_lens) = fluid.layers.read_file(pyreader) + + ernie = ErnieModel( + src_ids=src_ids, + position_ids=pos_ids, + sentence_ids=sent_ids, + input_mask=input_mask, + config=ernie_config) + + enc_out = ernie.get_sequence_output() + unpad_enc_out = fluid.layers.sequence_unpad(enc_out, length=seq_lens) + cls_feats = ernie.get_pooled_output() + + # set persistable = True to avoid memory opimizing + enc_out.persistable = True + unpad_enc_out.persistable = True + cls_feats.persistable = True + + graph_vars = { + "cls_embeddings": cls_feats, + "top_layer_embeddings": unpad_enc_out, + } + + return pyreader, graph_vars + + +def main(args): + args = parser.parse_args() + ernie_config = ErnieConfig(args.ernie_config_path) + ernie_config.print_config() + + if args.use_cuda: + place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) + dev_count = fluid.core.get_cuda_device_count() + else: + place = fluid.CPUPlace() + dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + exe = fluid.Executor(place) + + reader = 
task_reader.ExtractEmbeddingReader( + vocab_path=args.vocab_path, + max_seq_len=args.max_seq_len, + do_lower_case=args.do_lower_case, + random_seed=args.random_seed) + + startup_prog = fluid.Program() + if args.random_seed is not None: + startup_prog.random_seed = args.random_seed + + data_generator = reader.data_generator( + input_file=args.data_set, + batch_size=args.batch_size, + epoch=1, + shuffle=False, + phase="train") + + total_examples = reader.get_num_examples(args.data_set) + + print("Device count: %d" % dev_count) + print("Total num examples: %d" % total_examples) + + train_program = fluid.Program() + + with fluid.program_guard(train_program, startup_prog): + with fluid.unique_name.guard(): + pyreader, graph_vars = create_model( + args, pyreader_name='reader', ernie_config=ernie_config) + + fluid.memory_optimize(input_program=train_program) + + train_program = train_program.clone(for_test=True) + + exe.run(startup_prog) + + if args.init_pretraining_params: + init_pretraining_params( + exe, args.init_pretraining_params, main_program=startup_prog) + else: + raise ValueError( + "WARNING: args 'init_pretraining_params' must be specified") + + exec_strategy = fluid.ExecutionStrategy() + if args.use_fast_executor: + exec_strategy.use_experimental_executor = True + exec_strategy.num_threads = dev_count + exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope + + pyreader.decorate_tensor_provider(data_generator) + pyreader.start() + + total_cls_emb = [] + total_top_layer_emb = [] + total_labels = [] + while True: + try: + cls_emb, unpad_top_layer_emb = exe.run( + program=train_program, + fetch_list=[ + graph_vars["cls_embeddings"].name, graph_vars[ + "top_layer_embeddings"].name + ], + return_numpy=False) + # batch_size * embedding_size + total_cls_emb.append(np.array(cls_emb)) + total_top_layer_emb.append(np.array(unpad_top_layer_emb)) + except fluid.core.EOFException: + break + + total_cls_emb = np.concatenate(total_cls_emb) + 
total_top_layer_emb = np.concatenate(total_top_layer_emb) + + with open(os.path.join(args.output_dir, "cls_emb.npy"), + "w") as cls_emb_file: + np.save(cls_emb_file, total_cls_emb) + with open(os.path.join(args.output_dir, "top_layer_emb.npy"), + "w") as top_layer_emb_file: + np.save(top_layer_emb_file, total_top_layer_emb) + + +if __name__ == '__main__': + args = parser.parse_args() + print_arguments(args) + + main(args) diff --git a/ERNIE/finetune/classifier.py b/ERNIE/finetune/classifier.py index 8a69c3d5dfe65a9b8d2f75a1820ddf0ee1f4a476..48139c8d6da84b8bf86f7a361eefd6df6987bc29 100644 --- a/ERNIE/finetune/classifier.py +++ b/ERNIE/finetune/classifier.py @@ -20,9 +20,8 @@ from __future__ import print_function import time import numpy as np -import paddle.fluid as fluid - from six.moves import xrange +import paddle.fluid as fluid from model.ernie import ErnieModel diff --git a/ERNIE/model/ernie.py b/ERNIE/model/ernie.py index 3ccfb72a43b8385dc4f5e92ec2a446d40c384788..f2b5c0faaec150cf619e0b9ff45d103790e2539d 100644 --- a/ERNIE/model/ernie.py +++ b/ERNIE/model/ernie.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import json -import numpy as np + +import six import paddle.fluid as fluid + from model.transformer_encoder import encoder, pre_process_layer diff --git a/ERNIE/model/transformer_encoder.py b/ERNIE/model/transformer_encoder.py index 93a77ebe480f0e4a8e2b4f2c0c18b23383075fb7..ac5d293f3b198e8529afa75940c5aaf0a9fdbfc4 100644 --- a/ERNIE/model/transformer_encoder.py +++ b/ERNIE/model/transformer_encoder.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from functools import partial -import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers diff --git a/ERNIE/reader/task_reader.py b/ERNIE/reader/task_reader.py index 664e6d635a3ab4eeba04f504829099d34fde6d71..c00bdb51f575395c5f437ec8f9d815be2d60ad52 
class ExtractEmbeddingReader(BaseReader):
    """Reader that assembles label-free batches for embedding extraction."""

    def _pad_batch_records(self, batch_records):
        """Pad one batch of records into model-ready arrays.

        Args:
            batch_records: records exposing ``token_ids``, ``text_type_ids``
                and ``position_ids``.

        Returns:
            A list ``[token_ids, text_type_ids, position_ids, input_mask,
            seq_lens]``; the first three are padded with ``self.pad_id`` and
            the last two come from ``pad_batch_data`` with
            ``return_input_mask`` / ``return_seq_lens`` enabled.
        """
        token_id_rows = [rec.token_ids for rec in batch_records]
        type_id_rows = [rec.text_type_ids for rec in batch_records]
        position_id_rows = [rec.position_ids for rec in batch_records]

        # Only the token ids call asks for the mask and the lengths; the
        # other two fields just need padding.
        padded_tokens, mask, lengths = pad_batch_data(
            token_id_rows,
            pad_idx=self.pad_id,
            return_input_mask=True,
            return_seq_lens=True)
        padded_types = pad_batch_data(type_id_rows, pad_idx=self.pad_id)
        padded_positions = pad_batch_data(
            position_id_rows, pad_idx=self.pad_id)

        return [padded_tokens, padded_types, padded_positions, mask, lengths]
batch_size=args.batch_size, epoch=1, shuffle=False)) - evaluate(exe, test_prog, test_pyreader, graph_vars, "test") + evaluate(exe, test_prog, test_pyreader, graph_vars, + "test") except fluid.core.EOFException: save_path = os.path.join(args.checkpoints, "step_" + str(steps)) fluid.io.save_persistables(exe, save_path, train_program) diff --git a/ERNIE/run_sequence_labeling.py b/ERNIE/run_sequence_labeling.py index 1ee7544293e3bf920b21229a90b1341966617f18..feffb8fe915bce399ab038b8add43b30235430b3 100644 --- a/ERNIE/run_sequence_labeling.py +++ b/ERNIE/run_sequence_labeling.py @@ -19,10 +19,8 @@ from __future__ import print_function import os import time -import numpy as np import multiprocessing -import paddle import paddle.fluid as fluid import reader.task_reader as task_reader @@ -264,7 +262,8 @@ def main(args): epoch=1, shuffle=False)) print("Final validation result:") - evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, "dev") + evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, + "dev") # final eval on test set if args.do_test: @@ -275,7 +274,8 @@ def main(args): epoch=1, shuffle=False)) print("Final test result:") - evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, "test") + evaluate(exe, test_prog, test_pyreader, graph_vars, args.num_labels, + "test") if __name__ == '__main__': diff --git a/ERNIE/train.py b/ERNIE/train.py index 665ab85e6114b9c9e7d6e8b9a5212dfb32c66661..4faec96ba07aa6931ba5c96a0cacc53209058924 100644 --- a/ERNIE/train.py +++ b/ERNIE/train.py @@ -19,17 +19,15 @@ from __future__ import print_function import os import time -import argparse -import numpy as np import multiprocessing -import paddle +import numpy as np import paddle.fluid as fluid from reader.pretraining import ErnieDataReader from model.ernie import ErnieModel, ErnieConfig from optimization import optimization -from utils.args import ArgumentGroup, print_arguments +from utils.args import print_arguments from utils.init import 
init_checkpoint, init_pretraining_params from pretrain_args import parser