提交 53c68177 编写于 作者: X xixiaoyao

add mlm task and fix bugs

上级 dc1c43e8
task_instance: "mrqa, match4mrqa" task_instance: "mrqa, mlm4mrqa, match4mrqa"
target_tag: 1, 0 target_tag: 1, 0, 0
mix_ratio: 1.0, 0.5 mix_ratio: 0.5, 1.0, 0.5
save_path: "output_model/secondrun" save_path: "output_model/secondrun"
...@@ -12,9 +12,10 @@ do_lower_case: True ...@@ -12,9 +12,10 @@ do_lower_case: True
max_seq_len: 512 max_seq_len: 512
batch_size: 5 batch_size: 5
num_epochs: 2 num_epochs: 5
optimizer: "adam" optimizer: "adam"
learning_rate: 3e-5 learning_rate: 3e-5
warmup_proportion: 0.1 warmup_proportion: 0.1
weight_decay: 0.1 weight_decay: 0.1
print_every_n_steps: 1
此差异已折叠。
import paddlepalm as palm import paddlepalm as palm
if __name__ == '__main__': if __name__ == '__main__':
controller = palm.Controller('demo2_config.yaml', task_dir='demo2_tasks') controller = palm.Controller('config_demo2.yaml', task_dir='demo2_tasks')
controller.load_pretrain('pretrain_model/ernie/params') controller.load_pretrain('pretrain_model/ernie/params')
controller.train() controller.train()
controller = palm.Controller(config='demo2_config.yaml', task_dir='demo2_tasks', for_train=False) controller = palm.Controller(config='config_demo2.yaml', task_dir='demo2_tasks', for_train=False)
controller.pred('mrqa', inference_model_dir='output_model/secondrun/infer_model') controller.pred('mrqa', inference_model_dir='output_model/secondrun/infer_model')
train_file: "data/mlm4mrqa/train.txt"
reader: mlm
paradigm: mlm
...@@ -28,9 +28,7 @@ from paddlepalm.interface import backbone ...@@ -28,9 +28,7 @@ from paddlepalm.interface import backbone
class Model(backbone): class Model(backbone):
def __init__(self, def __init__(self, config, phase):
config,
phase):
# self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变 # self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变
self._emb_size = config["hidden_size"] self._emb_size = config["hidden_size"]
...@@ -56,16 +54,17 @@ class Model(backbone): ...@@ -56,16 +54,17 @@ class Model(backbone):
@property @property
def inputs_attr(self): def inputs_attr(self):
return {"token_ids": [-1, self._max_position_seq_len, 1], 'int64'], return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [-1, self._max_position_seq_len, 1], 'int64'], "position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [-1, self._max_position_seq_len, 1], 'int64'], "segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [-1, self._max_position_seq_len, 1], 'float32']} "input_mask": [[-1, -1, 1], 'float32']}
@property @property
def outputs_attr(self): def outputs_attr(self):
return {"word_emb": [-1, self._max_position_seq_len, self._emb_size], return {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"sentence_emb": [-1, self._emb_size], "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_pair_emb": [-1, self._emb_size]} "sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
def build(self, inputs): def build(self, inputs):
src_ids = inputs['token_ids'] src_ids = inputs['token_ids']
...@@ -146,9 +145,10 @@ class Model(backbone): ...@@ -146,9 +145,10 @@ class Model(backbone):
initializer = self._param_initializer), initializer = self._param_initializer),
bias_attr = "pooled_fc.b_0") bias_attr = "pooled_fc.b_0")
return {'word_emb': enc_out, return {'word_embedding': emb_out,
'sentence_emb': next_sent_feat, 'encoder_outputs': enc_out,
'sentence_pair_emb': next_sent_feat} 'sentence_embedding': next_sent_feat,
'sentence_pair_embedding': next_sent_feat}
def postprocess(self, rt_outputs): def postprocess(self, rt_outputs):
pass pass
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid import layers
class Model(backbone):
def __init__(self, config, phase):
# self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变
self._emb_size = config["emb_size"]
self._voc_size = config["vocab_size"]
@property
def inputs_attr(self):
return {"token_ids": [-1, self._max_position_seq_len, 1], 'int64']}
@property
def outputs_attr(self):
return {"word_emb": [-1, self._max_position_seq_len, self._emb_size],
"sentence_emb": [-1, self._emb_size*2]}
def build(self, inputs):
tok_ids = inputs['token_ids']
emb_out = layers.embedding(
input=tok_ids,
size=[self._voc_size, self._emb_size],
dtype='float32',
param_attr=fluid.ParamAttr(
name='word_emb',
initializer=fluid.initializer.TruncatedNormal(scale=0.1)),
is_sparse=False)
sent_emb1 = layers.reduce_mean(emb_out, axis=1)
sent_emb2 = layers.reduce_max(emb_out, axis=1)
sent_emb = layers.concat([sent_emb1, sent_emb2], axis=1)
return {'word_emb': emb_out,
'sentence_emb': sent_emb}
def postprocess(self, rt_outputs):
pass
...@@ -71,6 +71,7 @@ class Model(backbone): ...@@ -71,6 +71,7 @@ class Model(backbone):
@property @property
def outputs_attr(self): def outputs_attr(self):
return {"word_embedding": [[-1, -1, self._emb_size], 'float32'], return {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'], "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'], "sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']} "sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
...@@ -91,6 +92,9 @@ class Model(backbone): ...@@ -91,6 +92,9 @@ class Model(backbone):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer), name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False) is_sparse=False)
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(self._word_emb_name)
position_emb_out = fluid.layers.embedding( position_emb_out = fluid.layers.embedding(
input=pos_ids, input=pos_ids,
...@@ -161,7 +165,8 @@ class Model(backbone): ...@@ -161,7 +165,8 @@ class Model(backbone):
name="pooled_fc.w_0", initializer=self._param_initializer), name="pooled_fc.w_0", initializer=self._param_initializer),
bias_attr="pooled_fc.b_0") bias_attr="pooled_fc.b_0")
return {'word_embedding': emb_out, return {'embedding_table': embedding_table,
'word_embedding': emb_out,
'encoder_outputs': enc_out, 'encoder_outputs': enc_out,
'sentence_embedding': next_sent_feat, 'sentence_embedding': next_sent_feat,
'sentence_pair_embedding': next_sent_feat} 'sentence_pair_embedding': next_sent_feat}
......
...@@ -422,7 +422,7 @@ class Controller(object): ...@@ -422,7 +422,7 @@ class Controller(object):
prefixes.append(inst.name) prefixes.append(inst.name)
mrs.append(inst.mix_ratio) mrs.append(inst.mix_ratio)
joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, batch_size=main_conf['batch_size']) joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE)
input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)]
pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
...@@ -653,6 +653,10 @@ class Controller(object): ...@@ -653,6 +653,10 @@ class Controller(object):
loss, main_conf.get('print_every_n_steps', 5) / time_cost)) loss, main_conf.get('print_every_n_steps', 5) / time_cost))
time_begin = time.time() time_begin = time.time()
if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
print(cur_task.name+': train finished!')
cur_task.save()
if 'save_every_n_steps' in main_conf and global_step % main_conf['save_every_n_steps'] == 0: if 'save_every_n_steps' in main_conf and global_step % main_conf['save_every_n_steps'] == 0:
save_path = os.path.join(main_conf['save_path'], save_path = os.path.join(main_conf['save_path'],
"step_" + str(global_step)) "step_" + str(global_step))
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', False)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
# self._batch_size =
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from paddlepalm.interface import reader from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import BaseReader from paddlepalm.reader.utils.reader4ernie import MaskLMReader
class Reader(reader): class Reader(reader):
...@@ -26,7 +26,7 @@ class Reader(reader): ...@@ -26,7 +26,7 @@ class Reader(reader):
self._is_training = phase == 'train' self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'], reader = MaskLMReader(config['vocab_path'],
max_seq_len=config['max_seq_len'], max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False), do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False), for_cn=config.get('for_cn', False),
...@@ -59,21 +59,14 @@ class Reader(reader): ...@@ -59,21 +59,14 @@ class Reader(reader):
@property @property
def outputs_attr(self): def outputs_attr(self):
if self._is_training: return {"token_ids": [[-1, -1, 1], 'int64'],
return {"token_ids": [[-1, -1, 1], 'int64'], "position_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'], "segment_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'], "input_mask": [[-1, -1, 1], 'float32'],
"input_mask": [[-1, -1, 1], 'float32'], "task_ids": [[-1, -1, 1], 'int64'],
"label_ids": [[-1,1], 'int64'], "mask_label": [[-1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'] "mask_pos": [[-1, 1], 'int64']
} }
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self): def load_data(self):
...@@ -85,9 +78,6 @@ class Reader(reader): ...@@ -85,9 +78,6 @@ class Reader(reader):
names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask', names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
'task_ids', 'mask_label', 'mask_pos'] 'task_ids', 'mask_label', 'mask_pos']
outputs = {n: i for n,i in zip(names, x)} outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs return outputs
for batch in self._data_generator(): for batch in self._data_generator():
......
...@@ -93,6 +93,7 @@ def prepare_batch_data(insts, ...@@ -93,6 +93,7 @@ def prepare_batch_data(insts,
batch_sent_ids = [inst[1] for inst in insts] batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts] batch_pos_ids = [inst[2] for inst in insts]
# 这里是否应该反过来???否则在task layer里展开后的word embedding是padding后的,这时候word的index是跟没有padding时的index对不上的?
# First step: do mask without padding # First step: do mask without padding
out, mask_label, mask_pos = mask( out, mask_label, mask_pos = mask(
batch_src_ids, batch_src_ids,
...@@ -106,6 +107,7 @@ def prepare_batch_data(insts, ...@@ -106,6 +107,7 @@ def prepare_batch_data(insts,
out, out,
max_len=max_len, max_len=max_len,
pad_idx=pad_id, return_input_mask=True) pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data( pos_id = pad_batch_data(
batch_pos_ids, batch_pos_ids,
max_len=max_len, max_len=max_len,
......
...@@ -45,11 +45,7 @@ if six.PY3: ...@@ -45,11 +45,7 @@ if six.PY3:
def csv_reader(fd, delimiter='\t'): def csv_reader(fd, delimiter='\t'):
def gen(): def gen():
for i in fd: for i in fd:
slots = i.rstrip('\n').split(delimiter) yield i.rstrip('\n').split(delimiter)
if len(slots) == 1:
yield slots,
else:
yield slots
return gen() return gen()
...@@ -74,6 +70,7 @@ class BaseReader(object): ...@@ -74,6 +70,7 @@ class BaseReader(object):
self.pad_id = self.vocab["[PAD]"] self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"] self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"] self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.in_tokens = in_tokens self.in_tokens = in_tokens
self.is_inference = is_inference self.is_inference = is_inference
self.for_cn = for_cn self.for_cn = for_cn
...@@ -242,7 +239,6 @@ class BaseReader(object): ...@@ -242,7 +239,6 @@ class BaseReader(object):
batch_records, max_len = [record], len(record.token_ids) batch_records, max_len = [record], len(record.token_ids)
if phase == 'pred' and batch_records: if phase == 'pred' and batch_records:
print('the last batch yielded.')
yield self._pad_batch_records(batch_records) yield self._pad_batch_records(batch_records)
def get_num_examples(self, input_file=None, phase=None): def get_num_examples(self, input_file=None, phase=None):
...@@ -371,31 +367,28 @@ class MaskLMReader(BaseReader): ...@@ -371,31 +367,28 @@ class MaskLMReader(BaseReader):
token_ids = tokenizer.convert_tokens_to_ids(tokens) token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids))) position_ids = list(range(len(token_ids)))
Record = namedtuple('Record', # Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids']) # ['token_ids', 'text_type_ids', 'position_ids'])
record = Record( # record = Record(
token_ids=token_ids, # token_ids=token_ids,
text_type_ids=text_type_ids, # text_type_ids=text_type_ids,
position_ids=position_ids) # position_ids=position_ids)
return record return [token_ids, text_type_ids, position_ids]
def batch_reader(examples, batch_size, in_tokens, phase): def batch_reader(self, examples, batch_size, in_tokens, phase):
batch, total_token_num, max_len = [], 0, 0 batch = []
total_token_num = 0
for e in examples: for e in examples:
token_ids, sent_ids, pos_ids = _convert_example_to_record(e, self.max_seq_len, self.tokenizer) parsed_line = self._convert_example_to_record(e, self.max_seq_len, self.tokenizer)
max_len = max(max_len, len(token_ids)) to_append = len(batch) < batch_size
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append: if to_append:
batch.append(parsed_line) batch.append(parsed_line)
total_token_num += len(token_ids) total_token_num += len(parsed_line[0])
else: else:
yield batch, total_token_num yield batch, total_token_num
batch, total_token_num, max_len = [parsed_line], len( batch = [parsed_line]
token_ids), len(token_ids) total_token_num = len(parsed_line[0])
if len(batch) > 0 and phase == 'pred': if len(batch) > 0 and phase == 'pred':
yield batch, total_token_num yield batch, total_token_num
...@@ -426,17 +419,17 @@ class MaskLMReader(BaseReader): ...@@ -426,17 +419,17 @@ class MaskLMReader(BaseReader):
np.random.shuffle(examples) np.random.shuffle(examples)
all_dev_batches = [] all_dev_batches = []
for batch_data, total_token_num in batch_reader(examples, for batch_data, num_tokens in self.batch_reader(examples,
self.batch_size, self.in_tokens, phase=phase): batch_size, self.in_tokens, phase=phase):
batch_data = prepare_batch_data( batch_data = prepare_batch_data(
batch_data, batch_data,
total_token_num, num_tokens,
voc_size=self.voc_size, voc_size=len(self.vocab),
pad_id=self.pad_id, pad_id=self.pad_id,
cls_id=self.cls_id, cls_id=self.cls_id,
sep_id=self.sep_id, sep_id=self.sep_id,
mask_id=self.mask_id, mask_id=self.mask_id,
max_len=self.max_seq_len, # max_len=self.max_seq_len, # 注意,如果padding到最大长度,会导致mask_pos与实际位置不对应。因为mask pos是基于batch内最大长度来计算的。
return_input_mask=True, return_input_mask=True,
return_max_len=False, return_max_len=False,
return_num_token=False) return_num_token=False)
......
...@@ -70,7 +70,11 @@ class TaskInstance(object): ...@@ -70,7 +70,11 @@ class TaskInstance(object):
def build_task_layer(self, net_inputs, phase): def build_task_layer(self, net_inputs, phase):
output_vars = self._task_layer[phase].build(net_inputs) output_vars = self._task_layer[phase].build(net_inputs)
if phase == 'pred': if phase == 'pred':
self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items()) if output_vars is not None:
self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
else:
self._pred_fetch_name_list = []
self._pred_fetch_var_list = []
return output_vars return output_vars
def postprocess(self, rt_outputs, phase): def postprocess(self, rt_outputs, phase):
...@@ -234,8 +238,6 @@ class TaskInstance(object): ...@@ -234,8 +238,6 @@ class TaskInstance(object):
self._cur_train_step = 1 self._cur_train_step = 1
if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps: if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
self._train_finish = True self._train_finish = True
print(self._name+': train finished!')
self.save()
# fluid.io.save_inference_model(self._save_infermodel_path, ) # fluid.io.save_inference_model(self._save_infermodel_path, )
@property @property
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm from paddlepalm.interface import task_paradigm
from paddle.fluid import layers from paddle.fluid import layers
from paddlepalm.backbone.utils.transformer import pre_process_layer
class TaskParadigm(task_paradigm): class TaskParadigm(task_paradigm):
''' '''
...@@ -23,6 +24,7 @@ class TaskParadigm(task_paradigm): ...@@ -23,6 +24,7 @@ class TaskParadigm(task_paradigm):
''' '''
def __init__(self, config, phase, backbone_config=None): def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train' self._is_training = phase == 'train'
self._emb_size = backbone_config['hidden_size']
self._hidden_size = backbone_config['hidden_size'] self._hidden_size = backbone_config['hidden_size']
self._vocab_size = backbone_config['vocab_size'] self._vocab_size = backbone_config['vocab_size']
self._hidden_act = backbone_config['hidden_act'] self._hidden_act = backbone_config['hidden_act']
...@@ -30,11 +32,14 @@ class TaskParadigm(task_paradigm): ...@@ -30,11 +32,14 @@ class TaskParadigm(task_paradigm):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
if self._is_training: reader = {
reader = {"label_ids": [[-1, 1], 'int64']} "mask_label": [[-1, 1], 'int64'],
else: "mask_pos": [[-1, 1], 'int64']}
reader = {} if not self._is_training:
bb = {"encoder_outputs": [[-1, self._hidden_size], 'float32']} del reader['mask_label']
bb = {
"encoder_outputs": [[-1, -1, self._hidden_size], 'float32'],
"embedding_table": [[-1, self._vocab_size, self._emb_size], 'float32']}
return {'reader': reader, 'backbone': bb} return {'reader': reader, 'backbone': bb}
@property @property
...@@ -42,12 +47,13 @@ class TaskParadigm(task_paradigm): ...@@ -42,12 +47,13 @@ class TaskParadigm(task_paradigm):
if self._is_training: if self._is_training:
return {"loss": [[1], 'float32']} return {"loss": [[1], 'float32']}
else: else:
return {"logits": [[-1, 1], 'float32']} return {"logits": [[-1], 'float32']}
def build(self, inputs): def build(self, inputs):
mask_label = inputs["reader"]["mask_label"] if self._is_training:
mask_label = inputs["reader"]["mask_label"]
mask_pos = inputs["reader"]["mask_pos"] mask_pos = inputs["reader"]["mask_pos"]
word_emb = inputs["backbone"]["word_embedding"] word_emb = inputs["backbone"]["embedding_table"]
enc_out = inputs["backbone"]["encoder_outputs"] enc_out = inputs["backbone"]["encoder_outputs"]
emb_size = word_emb.shape[-1] emb_size = word_emb.shape[-1]
...@@ -62,7 +68,6 @@ class TaskParadigm(task_paradigm): ...@@ -62,7 +68,6 @@ class TaskParadigm(task_paradigm):
# extract masked tokens' feature # extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos) mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
num_seqs = fluid.layers.fill_constant(shape=[1], value=512, dtype='int64')
# transform: fc # transform: fc
mask_trans_feat = fluid.layers.fc( mask_trans_feat = fluid.layers.fc(
...@@ -99,13 +104,12 @@ class TaskParadigm(task_paradigm): ...@@ -99,13 +104,12 @@ class TaskParadigm(task_paradigm):
attr=mask_lm_out_bias_attr, attr=mask_lm_out_bias_attr,
is_bias=True) is_bias=True)
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
loss = fluid.layers.mean(mask_lm_loss)
if self._is_training: if self._is_training:
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
loss = fluid.layers.mean(mask_lm_loss)
return {'loss': loss} return {'loss': loss}
else: else:
return None return {'logits': fc_out}
...@@ -22,7 +22,7 @@ from paddle import fluid ...@@ -22,7 +22,7 @@ from paddle import fluid
from paddle.fluid import layers from paddle.fluid import layers
def _check_and_adapt_shape_dtype(rt_val, attr): def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
if not isinstance(rt_val, np.ndarray): if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val) rt_val = np.array(rt_val)
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)." assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
...@@ -30,12 +30,12 @@ def _check_and_adapt_shape_dtype(rt_val, attr): ...@@ -30,12 +30,12 @@ def _check_and_adapt_shape_dtype(rt_val, attr):
rt_val = rt_val.astype('float32') rt_val = rt_val.astype('float32')
shape, dtype = attr shape, dtype = attr
assert rt_val.dtype == np.dtype(dtype), "yielded data type not consistent with attr settings." assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype))
assert len(shape) == rt_val.ndim, "yielded data rank(ndim) not consistent with attr settings." assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim)
for rt, exp in zip(rt_val.shape, shape): for rt, exp in zip(rt_val.shape, shape):
if exp is None or exp < 0: if exp is None or exp < 0:
continue continue
assert rt == exp, "yielded data shape is not consistent with attr settings.\nExpected:{}\nActual:{}".format(exp, rt) assert rt == exp, "yielded data shape is not consistent with attr settings.Expected:{}Actual:{}".format(exp, rt)
return rt_val return rt_val
...@@ -107,7 +107,7 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p ...@@ -107,7 +107,7 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p
return iterator return iterator
def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0, batch_size=None): def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0):
""" """
joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查 joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查
""" """
...@@ -130,12 +130,12 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype ...@@ -130,12 +130,12 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
if outname in outname_to_pos: if outname in outname_to_pos:
idx = outname_to_pos[outname] idx = outname_to_pos[outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx]) val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
results[idx] = val results[idx] = val
if task_outname in outname_to_pos: if task_outname in outname_to_pos:
idx = outname_to_pos[task_outname] idx = outname_to_pos[task_outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx]) val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
results[idx] = val results[idx] = val
fake_batch = results fake_batch = results
...@@ -153,7 +153,6 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype ...@@ -153,7 +153,6 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
results[0] = task_id_tensor results[0] = task_id_tensor
for i in range(dev_count): for i in range(dev_count):
# results = _zero_batch(joint_shape_and_dtypes, batch_size=batch_size)
results[0] = task_id_tensor results[0] = task_id_tensor
if id in outbuf: if id in outbuf:
outputs = outbuf[id] outputs = outbuf[id]
...@@ -171,14 +170,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype ...@@ -171,14 +170,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
idx = outname_to_pos[outname] idx = outname_to_pos[outname]
if v > 0: if v > 0:
print(outname + ' is insert in idx ' + str(idx)) print(outname + ' is insert in idx ' + str(idx))
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx]) val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
results[idx] = val results[idx] = val
if task_outname in outname_to_pos: if task_outname in outname_to_pos:
idx = outname_to_pos[task_outname] idx = outname_to_pos[task_outname]
if v > 0: if v > 0:
print(task_outname + ' is insert in idx ' + str(idx)) print(task_outname + ' is insert in idx ' + str(idx))
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx]) val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
results[idx] = val results[idx] = val
if v > 0: if v > 0:
......
export CUDA_VISIBLE_DEVICES=0,1,2,3 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python demo2.py python demo2.py
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册