Commit 388c1c83 authored by Yibing Liu

Simplify ernie model structure

Parent 92f5f78f
@@ -124,7 +124,7 @@ def prepare_batch_data(insts,
                        cls_id=None,
                        sep_id=None,
                        mask_id=None,
-                       return_attn_bias=True,
+                       return_input_mask=True,
                        return_max_len=True,
                        return_num_token=False):
@@ -149,14 +149,13 @@ def prepare_batch_data(insts,
         MASK=mask_id)
     # Second step: padding
-    src_id, next_sent_index, self_attn_bias = pad_batch_data(
-        out, pad_idx=pad_id, return_next_sent_pos=True, return_attn_bias=True)
+    src_id, self_input_mask = pad_batch_data(
+        out, pad_idx=pad_id, return_input_mask=True)
     pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
     sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)

     return_list = [
-        src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos, labels,
-        next_sent_index
+        src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos, labels
     ]

     return return_list
@@ -165,8 +164,7 @@ def prepare_batch_data(insts,
 def pad_batch_data(insts,
                    pad_idx=0,
                    return_pos=False,
-                   return_next_sent_pos=False,
-                   return_attn_bias=False,
+                   return_input_mask=False,
                    return_max_len=False,
                    return_num_token=False):
     """
@@ -182,15 +180,6 @@ def pad_batch_data(insts,
         [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
     return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]

-    # next_sent_pos for extract first token embedding of each sentence
-    if return_next_sent_pos:
-        batch_size = inst_data.shape[0]
-        max_seq_len = inst_data.shape[1]
-        next_sent_index = np.array(
-            range(0, batch_size * max_seq_len, max_seq_len)).astype(
-                "int64").reshape(-1, 1)
-        return_list += [next_sent_index]
-
     # position data
     if return_pos:
         inst_pos = np.array([
@@ -200,13 +189,12 @@ def pad_batch_data(insts,
         return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]

-    if return_attn_bias:
+    if return_input_mask:
         # This is used to avoid attention on paddings.
-        slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
-                                       (max_len - len(inst)) for inst in insts])
-        slf_attn_bias_data = np.tile(
-            slf_attn_bias_data.reshape([-1, 1, max_len]), [1, max_len, 1])
-        return_list += [slf_attn_bias_data.astype("float32")]
+        input_mask_data = np.array([[1] * len(inst) + [0] *
+                                    (max_len - len(inst)) for inst in insts])
+        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
+        return_list += [input_mask_data.astype("float32")]

     if return_max_len:
         return_list += [max_len]
...
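The hunks above replace the precomputed `[batch, max_len, max_len]` attention-bias matrix with a much smaller `[batch, max_len, 1]` mask of ones and zeros. A minimal NumPy sketch of what `pad_batch_data` now produces, using toy token ids that are not from the repo:

```python
import numpy as np

insts = [[5, 2, 9], [7, 4]]  # two hypothetical token-id sequences
pad_id = 0
max_len = max(len(inst) for inst in insts)

# padded ids, shaped [batch, max_len, 1] as the py_reader slots expect
inst_data = np.array(
    [inst + [pad_id] * (max_len - len(inst)) for inst in insts])
src_id = inst_data.astype("int64").reshape([-1, max_len, 1])

# the new input mask: 1.0 over real tokens, 0.0 over padding
input_mask = np.array(
    [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts])
input_mask = np.expand_dims(input_mask, axis=-1).astype("float32")

print(src_id.squeeze(-1))      # [[5 2 9] [7 4 0]]
print(input_mask.squeeze(-1))  # [[1. 1. 1.] [1. 1. 0.]]
```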
@@ -31,26 +31,25 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     pyreader = fluid.layers.py_reader(
         capacity=50,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1],
-                [-1, 1]],
-        dtypes=['int64', 'int64', 'int64', 'float', 'int64', 'int64', 'int64'],
-        lod_levels=[0, 0, 0, 0, 0, 0, 0],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1],
+                [-1, 1]],
+        dtypes=['int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
+        lod_levels=[0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)

-    (src_ids, sent_ids, pos_ids, self_attn_mask, labels, next_sent_index,
+    (src_ids, sent_ids, pos_ids, input_mask, labels,
      qids) = fluid.layers.read_file(pyreader)

     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         use_fp16=args.use_fp16)

-    cls_feats = ernie.get_pooled_output(next_sent_index)
+    cls_feats = ernie.get_pooled_output()
     cls_feats = fluid.layers.dropout(
         x=cls_feats,
         dropout_prob=0.1,
@@ -67,8 +66,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     if is_prediction:
         probs = fluid.layers.softmax(logits)
         feed_targets_name = [
-            src_ids.name, pos_ids.name, sent_ids.name, self_attn_mask.name,
-            next_sent_index.name
+            src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
         ]
         return pyreader, probs, feed_targets_name
...
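With the attention-bias matrix gone, prediction feeds shrink to the four names above. A hedged sketch of how a caller might assemble the feed dict; `exe`, `inference_program`, `probs`, and `feed_targets_name` are assumed to come from the surrounding setup, and the arrays are dummy values:

```python
import numpy as np

batch, max_seq_len = 1, 128  # toy sizes, not from the repo
feed = dict(
    zip(feed_targets_name, [
        np.zeros((batch, max_seq_len, 1), dtype="int64"),   # src_ids
        np.zeros((batch, max_seq_len, 1), dtype="int64"),   # pos_ids
        np.zeros((batch, max_seq_len, 1), dtype="int64"),   # sent_ids
        np.ones((batch, max_seq_len, 1), dtype="float32"),  # input_mask
    ]))
results = exe.run(inference_program, feed=feed, fetch_list=[probs])
```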
@@ -29,28 +29,26 @@ from six.moves import xrange
 from model.ernie import ErnieModel

-def create_model(args,
-                 pyreader_name,
-                 ernie_config,
-                 is_prediction=False):
+def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     pyreader = fluid.layers.py_reader(
         capacity=50,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, args.max_seq_len],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                 [-1, args.max_seq_len, 1], [-1, 1]],
-        dtypes=['int64', 'int64', 'int64', 'float', 'int64', 'int64'],
+        dtypes=['int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
         lod_levels=[0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)

-    (src_ids, sent_ids, pos_ids, self_attn_mask, labels,
+    (src_ids, sent_ids, pos_ids, input_mask, labels,
      seq_lens) = fluid.layers.read_file(pyreader)

     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         use_fp16=args.use_fp16)
@@ -63,33 +61,40 @@ def create_model(args,
             name="cls_seq_label_out_w",
             initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
         bias_attr=fluid.ParamAttr(
-            name="cls_seq_label_out_b", initializer=fluid.initializer.Constant(0.)))
+            name="cls_seq_label_out_b",
+            initializer=fluid.initializer.Constant(0.)))

-    ret_labels = fluid.layers.reshape(x=labels, shape=[-1,1])
-    ret_infers = fluid.layers.reshape(x=fluid.layers.argmax(logits, axis=2), shape=[-1,1])
+    ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1])
+    ret_infers = fluid.layers.reshape(
+        x=fluid.layers.argmax(
+            logits, axis=2), shape=[-1, 1])
     labels = fluid.layers.flatten(labels, axis=2)
     ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
-        logits=fluid.layers.flatten(logits, axis=2),
-        label=labels, return_softmax=True)
+        logits=fluid.layers.flatten(
+            logits, axis=2),
+        label=labels,
+        return_softmax=True)
     loss = fluid.layers.mean(x=ce_loss)

     if args.use_fp16 and args.loss_scaling > 1.0:
         loss *= args.loss_scaling

-    graph_vars = {"loss": loss,
-                  "probs": probs,
-                  "labels": ret_labels,
-                  "infers": ret_infers,
-                  "seq_lens": seq_lens}
+    graph_vars = {
+        "loss": loss,
+        "probs": probs,
+        "labels": ret_labels,
+        "infers": ret_infers,
+        "seq_lens": seq_lens
+    }

     for k, v in graph_vars.items():
-        v.persistable=True
+        v.persistable = True

     return pyreader, graph_vars

+
 def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):

     def extract_bio_chunk(seq):
         chunks = []
         cur_chunk = None
@@ -109,18 +114,18 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
             if cur_chunk is not None:
                 chunks.append(cur_chunk)
                 cur_chunk = {}
-                cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
             else:
                 if cur_chunk is None:
-                    cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
                     continue

                 if cur_chunk["type"] == tag_type:
                     cur_chunk["en"] = index + 1
                 else:
                     chunks.append(cur_chunk)
-                    cur_chunk = {"st":index, "en": index + 1, "type": tag_type}
+                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}

         if cur_chunk is not None:
             chunks.append(cur_chunk)
@@ -151,10 +156,13 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
         infer_index = 0
         label_index = 0
-        while label_index < len(label_chunks) and infer_index < len(infer_chunks):
-            if infer_chunks[infer_index]["st"] < label_chunks[label_index]["st"]:
+        while label_index < len(label_chunks) and infer_index < len(
+                infer_chunks):
+            if infer_chunks[infer_index]["st"] < label_chunks[label_index][
+                    "st"]:
                 infer_index += 1
-            elif infer_chunks[infer_index]["st"] > label_chunks[label_index]["st"]:
+            elif infer_chunks[infer_index]["st"] > label_chunks[
+                    label_index]["st"]:
                 label_index += 1
             else:
                 if infer_chunks[infer_index]["en"] == label_chunks[label_index]["en"] and \
@@ -168,6 +176,7 @@ def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
     return num_label, num_infer, num_correct

+
 def calculate_f1(num_label, num_infer, num_correct):
     if num_infer == 0:
         precision = 0.0
@@ -185,10 +194,18 @@ def calculate_f1(num_label, num_infer, num_correct):
         f1 = 2 * precision * recall / (precision + recall)
     return precision, recall, f1

-def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=1):
-    fetch_list = [graph_vars["labels"].name,
-                  graph_vars["infers"].name,
-                  graph_vars["seq_lens"].name]
+
+def evaluate(exe,
+             program,
+             pyreader,
+             graph_vars,
+             tag_num,
+             eval_phase,
+             dev_count=1):
+    fetch_list = [
+        graph_vars["labels"].name, graph_vars["infers"].name,
+        graph_vars["seq_lens"].name
+    ]

     if eval_phase == "train":
         fetch_list.append(graph_vars["loss"].name)
@@ -196,9 +213,15 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
             fetch_list.append(graph_vars["learning_rate"].name)
         outputs = exe.run(fetch_list=fetch_list)
         np_labels, np_infers, np_lens, np_loss = outputs[:4]
-        num_label, num_infer, num_correct = chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count)
+        num_label, num_infer, num_correct = chunk_eval(
+            np_labels, np_infers, np_lens, tag_num, dev_count)
         precision, recall, f1 = calculate_f1(num_label, num_infer, num_correct)
-        outputs = {"precision": precision, "recall": recall, "f1": f1, "loss": np.mean(np_loss)}
+        outputs = {
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+            "loss": np.mean(np_loss)
+        }
         if "learning_rate" in graph_vars:
             outputs["lr"] = float(outputs[4][0])
         return outputs
@@ -209,8 +232,10 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
         pyreader.start()
         while True:
             try:
-                np_labels, np_infers, np_lens = exe.run(program=program, fetch_list=fetch_list)
-                label_num, infer_num, correct_num = chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count)
+                np_labels, np_infers, np_lens = exe.run(program=program,
+                                                        fetch_list=fetch_list)
+                label_num, infer_num, correct_num = chunk_eval(
+                    np_labels, np_infers, np_lens, tag_num, dev_count)
                 total_infer += infer_num
                 total_label += label_num
                 total_correct += correct_num
@@ -219,8 +244,10 @@ def evaluate(exe, program, pyreader, graph_vars, tag_num, eval_phase, dev_count=
                 pyreader.reset()
                 break

-        precision, recall, f1 = calculate_f1(total_label, total_infer, total_correct)
+        precision, recall, f1 = calculate_f1(total_label, total_infer,
+                                             total_correct)
         time_end = time.time()
-        print("[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s" %
-              (eval_phase, f1, precision, recall, time_end - time_begin))
+        print(
+            "[%s evaluation] f1: %f, precision: %f, recall: %f, elapsed time: %f s"
+            % (eval_phase, f1, precision, recall, time_end - time_begin))
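For reference, `chunk_eval` returns the counts `(num_label, num_infer, num_correct)` and `calculate_f1` turns them into chunk-level precision, recall, and F1. Worked through with assumed counts, not taken from any run:

```python
# Assumed, purely illustrative: 10 gold chunks, 8 predicted, 6 matching.
num_label, num_infer, num_correct = 10, 8, 6

precision = num_correct / num_infer                  # 0.75
recall = num_correct / num_label                     # 0.60
f1 = 2 * precision * recall / (precision + recall)   # 0.9 / 1.35 = 0.666...
print(precision, recall, f1)
```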
@@ -52,7 +52,7 @@ class ErnieModel(object):
                  src_ids,
                  position_ids,
                  sentence_ids,
-                 self_attn_mask,
+                 input_mask,
                  config,
                  weight_sharing=True,
                  use_fp16=False):
@@ -78,9 +78,9 @@ class ErnieModel(object):
         self._param_initializer = fluid.initializer.TruncatedNormal(
             scale=config['initializer_range'])

-        self._build_model(src_ids, position_ids, sentence_ids, self_attn_mask)
+        self._build_model(src_ids, position_ids, sentence_ids, input_mask)

-    def _build_model(self, src_ids, position_ids, sentence_ids, self_attn_mask):
+    def _build_model(self, src_ids, position_ids, sentence_ids, input_mask):
         # padding id in vocabulary must be set to 0
         emb_out = fluid.layers.embedding(
             input=src_ids,
@@ -110,9 +110,12 @@ class ErnieModel(object):
             emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')

         if self._dtype == "float16":
-            self_attn_mask = fluid.layers.cast(
-                x=self_attn_mask, dtype=self._dtype)
+            input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)

+        self_attn_mask = fluid.layers.matmul(
+            x=input_mask, y=input_mask, transpose_y=True)
+        self_attn_mask = fluid.layers.scale(
+            x=self_attn_mask, scale=1000.0, bias=-1.0, bias_after_scale=False)
         n_head_self_attn_mask = fluid.layers.stack(
             x=[self_attn_mask] * self._n_head, axis=1)
         n_head_self_attn_mask.stop_gradient = True
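The bias matrix is now built in-graph: the outer product plus scale does what the reader used to precompute. A NumPy check of the construction (with `bias_after_scale=False`, `fluid.layers.scale` computes `(x + bias) * scale`), using a toy mask where the third position is padding:

```python
import numpy as np

input_mask = np.array([[[1.], [1.], [0.]]], dtype="float32")  # [batch=1, seq=3, 1]

# matmul(input_mask, input_mask, transpose_y=True): 1 only where both
# the query and the key positions are real tokens
pair_mask = input_mask @ input_mask.transpose(0, 2, 1)        # [1, 3, 3]

# scale(scale=1000.0, bias=-1.0, bias_after_scale=False) == (x - 1) * 1000
self_attn_bias = (pair_mask - 1.0) * 1000.0

print(self_attn_bias[0])
# [[    0.     0. -1000.]
#  [    0.     0. -1000.]
#  [-1000. -1000. -1000.]]
```

One plausible benefit of deriving the bias from a 0/1 mask after the fp16 cast: the old -1e9 constant overflows float16, while the scaled bias here stays at a representable -1000.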
@@ -138,13 +141,10 @@ class ErnieModel(object):
     def get_sequence_output(self):
         return self._enc_out

-    def get_pooled_output(self, next_sent_index):
+    def get_pooled_output(self):
         """Get the first feature of each sequence for classification"""
-        self._reshaped_emb_out = fluid.layers.reshape(
-            x=self._enc_out, shape=[-1, self._emb_size], inplace=True)
-        next_sent_index = fluid.layers.cast(x=next_sent_index, dtype='int32')
-        next_sent_feat = fluid.layers.gather(
-            input=self._reshaped_emb_out, index=next_sent_index)
+        next_sent_feat = fluid.layers.slice(
+            input=self._enc_out, axes=[1], starts=[0], ends=[1])
         next_sent_feat = fluid.layers.fc(
             input=next_sent_feat,
             size=self._emb_size,
@@ -154,17 +154,17 @@ class ErnieModel(object):
             bias_attr="pooled_fc.b_0")
         return next_sent_feat

-    def get_pretraining_output(self, mask_label, mask_pos, labels,
-                               next_sent_index):
+    def get_pretraining_output(self, mask_label, mask_pos, labels):
         """Get the loss & accuracy for pretraining"""

         mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')

         # extract the first token feature in each sentence
-        next_sent_feat = self.get_pooled_output(next_sent_index)
+        next_sent_feat = self.get_pooled_output()
+        reshaped_emb_out = fluid.layers.reshape(
+            x=self._enc_out, shape=[-1, self._emb_size])
         # extract masked tokens' feature
-        mask_feat = fluid.layers.gather(
-            input=self._reshaped_emb_out, index=mask_pos)
+        mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)

         # transform: fc
         mask_trans_feat = fluid.layers.fc(
...
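The pooled output no longer needs a gather index, since the [CLS] feature is simply position 0 of every sequence; only the masked-LM gather still needs the flattened `[batch * seq, emb]` view, which this hunk makes a local variable rather than an instance attribute. A NumPy sketch with toy shapes showing that the old gather path and the new slice select the same features:

```python
import numpy as np

batch, seq_len, emb = 2, 3, 4
enc_out = np.arange(batch * seq_len * emb, dtype="float32").reshape(
    batch, seq_len, emb)

# old path: flatten to [batch * seq_len, emb], gather rows 0, seq_len, ...
reshaped_emb_out = enc_out.reshape(-1, emb)
next_sent_index = np.arange(0, batch * seq_len, seq_len)
old_feat = reshaped_emb_out[next_sent_index]          # [batch, emb]

# new path: slice position 0 along the sequence axis
new_feat = enc_out[:, 0:1, :].reshape(batch, emb)

assert np.allclose(old_feat, new_feat)
```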
@@ -171,9 +171,12 @@ class ErnieDataReader(object):
             if len(token_seq) > self.max_seq_len:
                 miss_num += 1
                 continue
-            type_seq = [0] * (len(left_tokens) + 2) + [1] * (len(right_tokens) + 1)
+            type_seq = [0] * (len(left_tokens) + 2) + [1] * (len(right_tokens) +
+                                                             1)
             pos_seq = range(len(token_seq))
-            seg_label_seq = [-1] + left_seg_labels + [-1] + right_seg_labels + [-1]
+            seg_label_seq = [-1] + left_seg_labels + [-1] + right_seg_labels + [
+                -1
+            ]

             assert len(token_seq) == len(type_seq) == len(pos_seq) == len(seg_label_seq), \
                 "[ERROR]len(src_id) == lne(sent_id) == len(pos_id) must be True"
@@ -290,7 +293,7 @@ class ErnieDataReader(object):
             cls_id=self.cls_id,
             sep_id=self.sep_id,
             mask_id=self.mask_id,
-            return_attn_bias=True,
+            return_input_mask=True,
             return_max_len=False,
             return_num_token=False)
...
@@ -247,11 +247,8 @@ class ClassifyReader(BaseReader):
             batch_qids = np.array([]).astype("int64").reshape([-1, 1])

         # padding
-        padded_token_ids, next_sent_index, self_attn_bias = pad_batch_data(
-            batch_token_ids,
-            pad_idx=self.pad_id,
-            return_next_sent_pos=True,
-            return_attn_bias=True)
+        padded_token_ids, input_mask = pad_batch_data(
+            batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
         padded_text_type_ids = pad_batch_data(
             batch_text_type_ids, pad_idx=self.pad_id)
         padded_position_ids = pad_batch_data(
@@ -259,7 +256,7 @@ class ClassifyReader(BaseReader):

         return_list = [
             padded_token_ids, padded_text_type_ids, padded_position_ids,
-            self_attn_bias, batch_labels, next_sent_index, batch_qids
+            input_mask, batch_labels, batch_qids
         ]

         return return_list
@@ -274,11 +271,8 @@ class SequenceLabelReader(BaseReader):
         batch_seq_lens = [len(record.token_ids) for record in batch_records]

         # padding
-        padded_token_ids, self_attn_bias = pad_batch_data(
-            batch_token_ids,
-            pad_idx=self.pad_id,
-            return_next_sent_pos=False,
-            return_attn_bias=True)
+        padded_token_ids, input_mask = pad_batch_data(
+            batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
         padded_text_type_ids = pad_batch_data(
             batch_text_type_ids, pad_idx=self.pad_id)
         padded_position_ids = pad_batch_data(
@@ -290,7 +284,7 @@ class SequenceLabelReader(BaseReader):

         return_list = [
             padded_token_ids, padded_text_type_ids, padded_position_ids,
-            self_attn_bias, padded_label_ids, batch_seq_lens
+            input_mask, padded_label_ids, batch_seq_lens
         ]

         return return_list
...
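Both task readers now return tensors whose order must match the py_reader slots declared in the corresponding create_model. The correspondence for ClassifyReader, read directly off the two diffs (annotation only):

```python
# ClassifyReader return_list  ->  classifier py_reader slot (shape, dtype)
# padded_token_ids            ->  src_ids     [-1, max_seq_len, 1]  int64
# padded_text_type_ids        ->  sent_ids    [-1, max_seq_len, 1]  int64
# padded_position_ids         ->  pos_ids     [-1, max_seq_len, 1]  int64
# input_mask                  ->  input_mask  [-1, max_seq_len, 1]  float32
# batch_labels                ->  labels      [-1, 1]               int64
# batch_qids                  ->  qids        [-1, 1]               int64
```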
@@ -43,31 +43,29 @@ def create_model(pyreader_name, ernie_config):
     pyreader = fluid.layers.py_reader(
         capacity=70,
         shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, 1],
-                [-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1],
-                [-1, 1], [-1, 1]],
+                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1], [-1, 1],
+                [-1, 1], [-1, 1]],
         dtypes=[
-            'int64', 'int64', 'int64', 'float', 'int64', 'int64', 'int64',
-            'int64'
+            'int64', 'int64', 'int64', 'float32', 'int64', 'int64', 'int64'
         ],
-        lod_levels=[0, 0, 0, 0, 0, 0, 0, 0],
+        lod_levels=[0, 0, 0, 0, 0, 0, 0],
         name=pyreader_name,
         use_double_buffer=True)

-    (src_ids, pos_ids, sent_ids, self_attn_mask, mask_label, mask_pos, labels,
-     next_sent_index) = fluid.layers.read_file(pyreader)
+    (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
+     labels) = fluid.layers.read_file(pyreader)

     ernie = ErnieModel(
         src_ids=src_ids,
         position_ids=pos_ids,
         sentence_ids=sent_ids,
-        self_attn_mask=self_attn_mask,
+        input_mask=input_mask,
         config=ernie_config,
         weight_sharing=args.weight_sharing,
         use_fp16=args.use_fp16)

     next_sent_acc, mask_lm_loss, total_loss = ernie.get_pretraining_output(
-        mask_label, mask_pos, labels, next_sent_index)
+        mask_label, mask_pos, labels)

     if args.use_fp16 and args.loss_scaling > 1.0:
         total_loss *= args.loss_scaling
...
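For pretraining, prepare_batch_data's return_list (first file above) now lines up one-to-one with these seven slots. Note that the slot order differs between tasks: pretraining unpacks (src, pos, sent, ...) while the classifier unpacks (src, sent, pos, ...), and each reader matches its own model file:

```python
# prepare_batch_data return_list -> pretraining py_reader slot (shape, dtype)
# src_id           -> src_ids     [-1, max_seq_len, 1]  int64
# pos_id           -> pos_ids     [-1, max_seq_len, 1]  int64
# sent_id          -> sent_ids    [-1, max_seq_len, 1]  int64
# self_input_mask  -> input_mask  [-1, max_seq_len, 1]  float32
# mask_label       -> mask_label  [-1, 1]               int64
# mask_pos         -> mask_pos    [-1, 1]               int64
# labels           -> labels      [-1, 1]               int64
```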