diff --git a/BERT/model/bert.py b/BERT/model/bert.py
index e698dda2979be539fcb30e6618fa0e4ff6e7e92a..c17803caed17e81fafd55f9b9ae9f2b539f9f39c 100644
--- a/BERT/model/bert.py
+++ b/BERT/model/bert.py
@@ -115,7 +115,7 @@ class BertModel(object):
         self_attn_mask = fluid.layers.matmul(
             x=input_mask, y=input_mask, transpose_y=True)
         self_attn_mask = fluid.layers.scale(
-            x=self_attn_mask, scale=1000.0, bias=-1.0, bias_after_scale=False)
+            x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
         n_head_self_attn_mask = fluid.layers.stack(
             x=[self_attn_mask] * self._n_head, axis=1)
         n_head_self_attn_mask.stop_gradient = True
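The BERT hunk above only changes the magnitude of the additive attention bias. With `bias_after_scale=False`, `fluid.layers.scale` computes `scale * (x + bias)`, so the op turns the 0/1 pairwise mask into `(mask - 1) * 10000`: 0 for visible positions and -10000 for anything touching padding, matching the -10000 used by the reference BERT implementation. A minimal numpy sketch of the arithmetic (toy shapes, illustration only):

```python
import numpy as np

# Toy batch: two sequences padded to length 4; the second has 2 real tokens.
input_mask = np.array([[1, 1, 1, 1],
                       [1, 1, 0, 0]], dtype="float32")[:, :, None]  # [batch, seq, 1]

# matmul(mask, mask^T): 1 where query and key are both real tokens, else 0.
pairwise = np.matmul(input_mask, input_mask.transpose(0, 2, 1))  # [batch, seq, seq]

# scale(x, scale=10000.0, bias=-1.0, bias_after_scale=False) == (x - 1) * 10000:
# 0 for visible pairs, -10000 for any pair involving padding.
attn_bias = (pairwise - 1.0) * 10000.0
print(attn_bias[1])  # rows/cols touching the padded positions get -10000
```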
diff --git a/ERNIE/README.md b/ERNIE/README.md
index cda2a9d2a6528f84a9932012f7806724d7c8a52d..35f85c63f54a4655c3c22d892bf01d69fb137d0e 100644
--- a/ERNIE/README.md
+++ b/ERNIE/README.md
@@ -166,7 +166,7 @@ nlpcc-dbqa是由国际自然语言处理和中文计算会议NLPCC于2016年举
 2) [任务数据下载](https://ernie.bj.bcebos.com/task_data.tgz)
 
 ### 安装
-本项目依赖于 Paddle Fluid 1.3.0,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。
+本项目依赖于 Paddle Fluid 1.3.1,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。
 
 **Note**: 预训练任务和finetune任务测试机器为P40, 显存22G;如果显存低于22G, 某些任务可能会因显存不足报错;
 
diff --git a/ERNIE/batching.py b/ERNIE/batching.py
index 797ab9f5938324df07ee3506870fc4cc21d6e75d..065b11b6e6874abc10b6c6f7728f5c8e87c41c47 100644
--- a/ERNIE/batching.py
+++ b/ERNIE/batching.py
@@ -124,7 +124,7 @@ def prepare_batch_data(insts,
                        cls_id=None,
                        sep_id=None,
                        mask_id=None,
-                       return_attn_bias=True,
+                       return_input_mask=True,
                        return_max_len=True,
                        return_num_token=False):
 
@@ -149,14 +149,13 @@ def prepare_batch_data(insts,
         MASK=mask_id)
 
     # Second step: padding
-    src_id, next_sent_index, self_attn_bias = pad_batch_data(
-        out, pad_idx=pad_id, return_next_sent_pos=True, return_attn_bias=True)
+    src_id, self_input_mask = pad_batch_data(
+        out, pad_idx=pad_id, return_input_mask=True)
    pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
     sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
 
     return_list = [
-        src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos, labels,
-        next_sent_index
+        src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos, labels
     ]
 
     return return_list
@@ -165,8 +164,7 @@ def prepare_batch_data(insts,
 def pad_batch_data(insts,
                    pad_idx=0,
                    return_pos=False,
-                   return_next_sent_pos=False,
-                   return_attn_bias=False,
+                   return_input_mask=False,
                    return_max_len=False,
                    return_num_token=False):
     """
@@ -182,15 +180,6 @@ def pad_batch_data(insts,
         [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
     return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
 
-    # next_sent_pos for extract first token embedding of each sentence
-    if return_next_sent_pos:
-        batch_size = inst_data.shape[0]
-        max_seq_len = inst_data.shape[1]
-        next_sent_index = np.array(
-            range(0, batch_size * max_seq_len, max_seq_len)).astype(
-                "int64").reshape(-1, 1)
-        return_list += [next_sent_index]
-
     # position data
     if return_pos:
         inst_pos = np.array([
@@ -200,13 +189,12 @@ def pad_batch_data(insts,
 
         return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
 
-    if return_attn_bias:
+    if return_input_mask:
         # This is used to avoid attention on paddings.
-        slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
-                                       (max_len - len(inst)) for inst in insts])
-        slf_attn_bias_data = np.tile(
-            slf_attn_bias_data.reshape([-1, 1, max_len]), [1, max_len, 1])
-        return_list += [slf_attn_bias_data.astype("float32")]
+        input_mask_data = np.array([[1] * len(inst) + [0] *
+                                    (max_len - len(inst)) for inst in insts])
+        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
+        return_list += [input_mask_data.astype("float32")]
 
     if return_max_len:
         return_list += [max_len]
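With this change `pad_batch_data` returns a `[batch_size, max_len, 1]` 0/1 `input_mask` instead of a pre-tiled `[batch_size, max_len, max_len]` additive bias, so the per-batch mask shrinks from O(len²) to O(len) and the bias construction moves into the model. A small sketch of the new branch (hypothetical helper name, same logic):

```python
import numpy as np

def make_input_mask(insts, max_len):
    # 1 marks a real token, 0 marks padding -- as in the new
    # return_input_mask branch of pad_batch_data.
    mask = np.array([[1] * len(inst) + [0] * (max_len - len(inst))
                     for inst in insts])
    return np.expand_dims(mask, axis=-1).astype("float32")

# Toy usage: three token-id sequences padded to length 5.
print(make_input_mask([[7, 8, 9], [3, 4], [5, 6, 7, 8, 9]], 5).shape)  # (3, 5, 1)
```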
diff --git a/ERNIE/finetune/classifier.py b/ERNIE/finetune/classifier.py
index 37415fb97f9b3e425d60ff414fe548d83c5f13b1..8a69c3d5dfe65a9b8d2f75a1820ddf0ee1f4a476 100644
--- a/ERNIE/finetune/classifier.py
+++ b/ERNIE/finetune/classifier.py
@@ -31,26 +31,25 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     pyreader = fluid.layers.py_reader(
         capacity=50,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1],
                 [-1, 1]],
-        dtypes=['int64', 'int64', 'int64', 'float', 'int64', 'int64', 'int64'],
-        lod_levels=[0, 0, 0, 0, 0, 0, 0],
+        dtypes=['int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
+        lod_levels=[0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)
 
-    (src_ids, sent_ids, pos_ids, self_attn_mask, labels, next_sent_index,
+    (src_ids, sent_ids, pos_ids, input_mask, labels,
      qids) = fluid.layers.read_file(pyreader)
 
     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         use_fp16=args.use_fp16)
 
-    cls_feats = ernie.get_pooled_output(next_sent_index)
+    cls_feats = ernie.get_pooled_output()
     cls_feats = fluid.layers.dropout(
         x=cls_feats,
         dropout_prob=0.1,
@@ -67,8 +66,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     if is_prediction:
         probs = fluid.layers.softmax(logits)
         feed_targets_name = [
-            src_ids.name, pos_ids.name, sent_ids.name, self_attn_mask.name,
-            next_sent_index.name
+            src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
         ]
         return pyreader, probs, feed_targets_name
 
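The classifier pyreader drops from seven slots to six: `next_sent_index` disappears, the mask slot changes from a `[batch, seq, seq]` bias to a `[batch, seq, 1]` mask, and the dtype is corrected from `'float'` to `'float32'`. The exported prediction feed shrinks accordingly. A hypothetical feed for the new four-tensor inference interface (keys illustrative, not the actual `Variable.name` values):

```python
import numpy as np

batch_size, max_seq_len = 2, 128
feed = {
    "src_ids": np.zeros((batch_size, max_seq_len, 1), dtype="int64"),
    "pos_ids": np.zeros((batch_size, max_seq_len, 1), dtype="int64"),
    "sent_ids": np.zeros((batch_size, max_seq_len, 1), dtype="int64"),
    # The 0/1 mask replaces both the old [batch, seq, seq] attention bias
    # and the separate next_sent_index tensor.
    "input_mask": np.ones((batch_size, max_seq_len, 1), dtype="float32"),
}
for name, value in feed.items():
    assert value.shape == (batch_size, max_seq_len, 1), name
```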
diff --git a/ERNIE/finetune/sequence_label.py b/ERNIE/finetune/sequence_label.py
index dab6a58c85587381c27a53a94306d360fc405ac6..3c5163ecd56af57334e9f7f49acc259309cc8994 100644
--- a/ERNIE/finetune/sequence_label.py
+++ b/ERNIE/finetune/sequence_label.py
@@ -29,28 +29,26 @@ from six.moves import xrange
 
 from model.ernie import ErnieModel
 
-def create_model(args,
-                 pyreader_name,
-                 ernie_config,
-                 is_prediction=False):
+
+def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     pyreader = fluid.layers.py_reader(
         capacity=50,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, args.max_seq_len],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                 [-1, args.max_seq_len, 1], [-1, 1]],
-        dtypes=['int64', 'int64', 'int64', 'float', 'int64', 'int64'],
+        dtypes=['int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
         lod_levels=[0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)
 
-    (src_ids, sent_ids, pos_ids, self_attn_mask, labels,
+    (src_ids, sent_ids, pos_ids, input_mask, labels,
      seq_lens) = fluid.layers.read_file(pyreader)
 
     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         use_fp16=args.use_fp16)
 
@@ -63,33 +61,40 @@ def create_model(args,
             name="cls_seq_label_out_w",
             initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
         bias_attr=fluid.ParamAttr(
-            name="cls_seq_label_out_b", initializer=fluid.initializer.Constant(0.)))
+            name="cls_seq_label_out_b",
+            initializer=fluid.initializer.Constant(0.)))
 
-    ret_labels = fluid.layers.reshape(x=labels, shape=[-1,1])
-    ret_infers = fluid.layers.reshape(x=fluid.layers.argmax(logits, axis=2), shape=[-1,1])
+    ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1])
+    ret_infers = fluid.layers.reshape(
+        x=fluid.layers.argmax(
+            logits, axis=2), shape=[-1, 1])
 
     labels = fluid.layers.flatten(labels, axis=2)
     ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
-        logits=fluid.layers.flatten(logits, axis=2),
-        label=labels, return_softmax=True)
+        logits=fluid.layers.flatten(
+            logits, axis=2),
+        label=labels,
+        return_softmax=True)
     loss = fluid.layers.mean(x=ce_loss)
 
     if args.use_fp16 and args.loss_scaling > 1.0:
         loss *= args.loss_scaling
 
-    graph_vars = {"loss": loss,
-                  "probs": probs,
-                  "labels": ret_labels,
-                  "infers": ret_infers,
-                  "seq_lens": seq_lens}
-
+    graph_vars = {
+        "loss": loss,
+        "probs": probs,
+        "labels": ret_labels,
+        "infers": ret_infers,
+        "seq_lens": seq_lens
+    }
 
     for k, v in graph_vars.items():
-        v.persistable=True
+        v.persistable = True
 
     return pyreader, graph_vars
 
-def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
+
+def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
     def extract_bio_chunk(seq):
         chunks = []
         cur_chunk = None
@@ -109,18 +114,18 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
                 if cur_chunk is not None:
                     chunks.append(cur_chunk)
                     cur_chunk = {}
-                cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
 
             else:
                 if cur_chunk is None:
-                    cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
                     continue
 
                 if cur_chunk["type"] == tag_type:
-                    cur_chunk["en"] = index + 1 
+                    cur_chunk["en"] = index + 1
                 else:
                     chunks.append(cur_chunk)
-                    cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
 
         if cur_chunk is not None:
             chunks.append(cur_chunk)
@@ -151,14 +156,19 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
             infer_index = 0
             label_index = 0
 
-            while label_index < len(label_chunks) and infer_index < len(infer_chunks):
-                if infer_chunks[infer_index]["st"] < label_chunks[label_index]["st"]:
+            while label_index < len(label_chunks) \
+                    and infer_index < len(infer_chunks):
+                if infer_chunks[infer_index]["st"] \
+                        < label_chunks[label_index]["st"]:
                     infer_index += 1
-                elif infer_chunks[infer_index]["st"] > label_chunks[label_index]["st"]:
+                elif infer_chunks[infer_index]["st"] \
+                        > label_chunks[label_index]["st"]:
                     label_index += 1
                 else:
-                    if infer_chunks[infer_index]["en"] == label_chunks[label_index]["en"] and \
-                       infer_chunks[infer_index]["type"] == label_chunks[label_index]["type"]:
+                    if infer_chunks[infer_index]["en"] \
+                            == label_chunks[label_index]["en"] \
+                            and infer_chunks[infer_index]["type"] \
+                            == label_chunks[label_index]["type"]:
                         num_correct += 1
 
                     infer_index += 1
@@ -168,6 +178,7 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
 
     return num_label, num_infer, num_correct
 
+
 def calculate_f1(num_label, num_infer, num_correct):
     if num_infer == 0:
         precision = 0.0
@@ -185,10 +196,18 @@ def calculate_f1(num_label, num_infer, num_correct):
         f1 = 2 * precision * recall / (precision + recall)
     return precision, recall, f1
 
-def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=1):
-    fetch_list = [graph_vars["labels"].name,
-                  graph_vars["infers"].name,
-                  graph_vars["seq_lens"].name]
+
+def evaluate(exe,
+             program,
+             pyreader,
+             graph_vars,
+             tag_num,
+             eval_phase,
+             dev_count=1):
+    fetch_list = [
+        graph_vars["labels"].name, graph_vars["infers"].name,
+        graph_vars["seq_lens"].name
+    ]
 
     if eval_phase == "train":
         fetch_list.append(graph_vars["loss"].name)
@@ -196,9 +215,15 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
             fetch_list.append(graph_vars["learning_rate"].name)
         outputs = exe.run(fetch_list=fetch_list)
         np_labels, np_infers, np_lens, np_loss = outputs[:4]
-        num_label, num_infer, num_correct = chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count)
+        num_label, num_infer, num_correct = chunk_eval(
+            np_labels, np_infers, np_lens, tag_num, dev_count)
         precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct)
-        outputs = {"precision": precision, "recall": recall, "f1": f1, "loss": np.mean(np_loss)}
+        outputs = {
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+            "loss": np.mean(np_loss)
+        }
         if "learning_rate" in graph_vars:
             outputs["lr"] = float(outputs[4][0])
         return outputs
@@ -209,8 +234,10 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
         pyreader.start()
         while True:
             try:
-                np_labels, np_infers, np_lens = exe.run(program=program, fetch_list=fetch_list)
-                label_num, infer_num, correct_num = chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count)
+                np_labels, np_infers, np_lens = exe.run(program=program,
+                                                        fetch_list=fetch_list)
+                label_num, infer_num, correct_num = chunk_eval(
+                    np_labels, np_infers, np_lens, tag_num, dev_count)
                 total_infer += infer_num
                 total_label += label_num
                 total_correct += correct_num
@@ -219,8 +246,10 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
                 pyreader.reset()
                 break
 
-        precision, recall, f1 = calculate_f1(total_label, total_infer, total_correct)
+        precision, recall, f1 = calculate_f1(total_label, total_infer,
+                                             total_correct)
         time_end = time.time()
-        print("[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s" %
-              (eval_phase, f1, precision, recall, time_end - time_begin))
+        print(
+            "[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s"
+            % (eval_phase, f1, precision, recall, time_end - time_begin))
 
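`chunk_eval` and `calculate_f1` are only reformatted here (PEP 8 spacing and line wraps), not changed in behavior: a predicted chunk still counts as correct only when start, end, and type all match a gold chunk, and F1 still guards the zero-denominator cases. A standalone restatement of the F1 arithmetic with a worked toy check:

```python
def f1_from_counts(num_label, num_infer, num_correct):
    # Mirrors calculate_f1 above, including the divide-by-zero guards.
    precision = num_correct / num_infer if num_infer else 0.0
    recall = num_correct / num_label if num_label else 0.0
    if num_correct == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)

# Toy check: 10 gold chunks, 8 predicted, 6 exact matches.
print(f1_from_counts(10.0, 8.0, 6.0))  # (0.75, 0.6, 0.666...)
```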
diff --git a/ERNIE/model/ernie.py b/ERNIE/model/ernie.py
index e42b2f4558097f90ed1732dc0a7dd29b3a011aee..3ccfb72a43b8385dc4f5e92ec2a446d40c384788 100644
--- a/ERNIE/model/ernie.py
+++ b/ERNIE/model/ernie.py
@@ -52,7 +52,7 @@ class ErnieModel(object):
                  src_ids,
                  position_ids,
                  sentence_ids,
-                 self_attn_mask,
+                 input_mask,
                  config,
                  weight_sharing=True,
                  use_fp16=False):
@@ -78,9 +78,9 @@ class ErnieModel(object):
         self._param_initializer = fluid.initializer.TruncatedNormal(
             scale=config['initializer_range'])
 
-        self._build_model(src_ids, position_ids, sentence_ids, self_attn_mask)
+        self._build_model(src_ids, position_ids, sentence_ids, input_mask)
 
-    def _build_model(self, src_ids, position_ids, sentence_ids, self_attn_mask):
+    def _build_model(self, src_ids, position_ids, sentence_ids, input_mask):
         # padding id in vocabulary must be set to 0
         emb_out = fluid.layers.embedding(
             input=src_ids,
@@ -110,9 +110,12 @@ class ErnieModel(object):
             emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
 
         if self._dtype == "float16":
-            self_attn_mask = fluid.layers.cast(
-                x=self_attn_mask, dtype=self._dtype)
+            input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
 
+        self_attn_mask = fluid.layers.matmul(
+            x=input_mask, y=input_mask, transpose_y=True)
+        self_attn_mask = fluid.layers.scale(
+            x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
         n_head_self_attn_mask = fluid.layers.stack(
             x=[self_attn_mask] * self._n_head, axis=1)
         n_head_self_attn_mask.stop_gradient = True
@@ -138,13 +141,10 @@ class ErnieModel(object):
     def get_sequence_output(self):
         return self._enc_out
 
-    def get_pooled_output(self, next_sent_index):
+    def get_pooled_output(self):
         """Get the first feature of each sequence for classification"""
-        self._reshaped_emb_out = fluid.layers.reshape(
-            x=self._enc_out, shape=[-1, self._emb_size], inplace=True)
-        next_sent_index = fluid.layers.cast(x=next_sent_index, dtype='int32')
-        next_sent_feat = fluid.layers.gather(
-            input=self._reshaped_emb_out, index=next_sent_index)
+        next_sent_feat = fluid.layers.slice(
+            input=self._enc_out, axes=[1], starts=[0], ends=[1])
         next_sent_feat = fluid.layers.fc(
             input=next_sent_feat,
             size=self._emb_size,
@@ -154,17 +154,17 @@ class ErnieModel(object):
             bias_attr="pooled_fc.b_0")
         return next_sent_feat
 
-    def get_pretraining_output(self, mask_label, mask_pos, labels,
-                               next_sent_index):
+    def get_pretraining_output(self, mask_label, mask_pos, labels):
         """Get the loss & accuracy for pretraining"""
         mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
 
         # extract the first token feature in each sentence
-        next_sent_feat = self.get_pooled_output(next_sent_index)
+        next_sent_feat = self.get_pooled_output()
+        reshaped_emb_out = fluid.layers.reshape(
+            x=self._enc_out, shape=[-1, self._emb_size])
 
         # extract masked tokens' feature
-        mask_feat = fluid.layers.gather(
-            input=self._reshaped_emb_out, index=mask_pos)
+        mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
 
         # transform: fc
         mask_trans_feat = fluid.layers.fc(
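Two structural changes land in `ErnieModel` itself: the pairwise attention bias is now derived from `input_mask` inside `_build_model` (the same matmul-and-scale pattern as the BERT fix above), and `get_pooled_output` slices the first ([CLS]) position directly instead of flattening `enc_out` and gathering with a precomputed `next_sent_index`. A numpy sketch showing the two pooling routes agree (toy shapes, illustration only):

```python
import numpy as np

batch, seq, emb = 4, 16, 8
enc_out = np.random.rand(batch, seq, emb).astype("float32")

# Old route: flatten to [batch * seq, emb], gather row 0 of each sequence.
next_sent_index = np.arange(0, batch * seq, seq)
gathered = enc_out.reshape(-1, emb)[next_sent_index]

# New route: slice position 0 along the sequence axis.
sliced = enc_out[:, 0:1, :].reshape(batch, emb)

assert np.allclose(gathered, sliced)  # identical [CLS] features, no index tensor
```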
diff --git a/ERNIE/reader/pretraining.py b/ERNIE/reader/pretraining.py
index 67d0e593591110db785191ad3f38b67c7a3e0d2d..c1233ad014ac9e8300bf33dc344de1e8c5a69c40 100644
--- a/ERNIE/reader/pretraining.py
+++ b/ERNIE/reader/pretraining.py
@@ -171,9 +171,12 @@ class ErnieDataReader(object):
             if len(token_seq) > self.max_seq_len:
                 miss_num += 1
                 continue
-            type_seq = [0] * (len(left_tokens) + 2) + [1] * (len(right_tokens) + 1)
+            type_seq = [0] * (len(left_tokens) + 2) + [1] * (len(right_tokens)
+                                                             + 1)
             pos_seq = range(len(token_seq))
-            seg_label_seq = [-1] + left_seg_labels + [-1] + right_seg_labels + [-1]
+            seg_label_seq = [-1] + left_seg_labels + [-1] + right_seg_labels + [
+                -1
+            ]
 
             assert len(token_seq) == len(type_seq) == len(pos_seq) == len(seg_label_seq), \
                 "[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True"
@@ -290,7 +293,7 @@ class ErnieDataReader(object):
                 cls_id=self.cls_id,
                 sep_id=self.sep_id,
                 mask_id=self.mask_id,
-                return_attn_bias=True,
+                return_input_mask=True,
                 return_max_len=False,
                 return_num_token=False)
diff --git a/ERNIE/reader/task_reader.py b/ERNIE/reader/task_reader.py
index 74130d6c8ff17a1a715d231ba68e10561c11622a..664e6d635a3ab4eeba04f504829099d34fde6d71 100644
--- a/ERNIE/reader/task_reader.py
+++ b/ERNIE/reader/task_reader.py
@@ -247,11 +247,8 @@ class ClassifyReader(BaseReader):
             batch_qids = np.array([]).astype("int64").reshape([-1, 1])
 
         # padding
-        padded_token_ids, next_sent_index, self_attn_bias = pad_batch_data(
-            batch_token_ids,
-            pad_idx=self.pad_id,
-            return_next_sent_pos=True,
-            return_attn_bias=True)
+        padded_token_ids, input_mask = pad_batch_data(
+            batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
         padded_text_type_ids = pad_batch_data(
             batch_text_type_ids, pad_idx=self.pad_id)
         padded_position_ids = pad_batch_data(
@@ -259,7 +256,7 @@ class ClassifyReader(BaseReader):
 
         return_list = [
             padded_token_ids, padded_text_type_ids, padded_position_ids,
-            self_attn_bias, batch_labels, next_sent_index, batch_qids
+            input_mask, batch_labels, batch_qids
         ]
 
         return return_list
@@ -274,11 +271,8 @@ class SequenceLabelReader(BaseReader):
         batch_seq_lens = [len(record.token_ids) for record in batch_records]
 
         # padding
-        padded_token_ids, self_attn_bias = pad_batch_data(
-            batch_token_ids,
-            pad_idx=self.pad_id,
-            return_next_sent_pos=False,
-            return_attn_bias=True)
+        padded_token_ids, input_mask = pad_batch_data(
+            batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
         padded_text_type_ids = pad_batch_data(
             batch_text_type_ids, pad_idx=self.pad_id)
         padded_position_ids = pad_batch_data(
@@ -290,7 +284,7 @@ class SequenceLabelReader(BaseReader):
 
         return_list = [
             padded_token_ids, padded_text_type_ids, padded_position_ids,
-            self_attn_bias, padded_label_ids, batch_seq_lens
+            input_mask, padded_label_ids, batch_seq_lens
         ]
 
         return return_list
diff --git a/ERNIE/train.py b/ERNIE/train.py
index e696c896cd071e756209cc5b62c610e53316db32..665ab85e6114b9c9e7d6e8b9a5212dfb32c66661 100644
--- a/ERNIE/train.py
+++ b/ERNIE/train.py
@@ -43,31 +43,29 @@ def create_model(pyreader_name, ernie_config):
     pyreader = fluid.layers.py_reader(
         capacity=70,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1],
                 [-1, 1], [-1, 1]],
         dtypes=[
-            'int64', 'int64', 'int64', 'float', 'int64', 'int64', 'int64',
-            'int64'
+            'int64', 'int64', 'int64', 'float32', 'int64', 'int64', 'int64'
         ],
-        lod_levels=[0, 0, 0, 0, 0, 0, 0, 0],
+        lod_levels=[0, 0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)
 
-    (src_ids, pos_ids, sent_ids, self_attn_mask, mask_label, mask_pos, labels,
-     next_sent_index) = fluid.layers.read_file(pyreader)
+    (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
+     labels) = fluid.layers.read_file(pyreader)
 
     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         weight_sharing=args.weight_sharing,
         use_fp16=args.use_fp16)
 
     next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
-        mask_label, mask_pos, labels, next_sent_index)
+        mask_label, mask_pos, labels)
 
     if args.use_fp16 and args.loss_scaling > 1.0:
         total_loss *= args.loss_scaling