diff --git a/.gitignore b/.gitignore index 4c0ad57a6a1158a8565bead716a45db79e9c1e6a..c84287abd6fd69747a5ba7e39177cf1b95db5330 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ __pycache__ pretrain_model output_model +mrqa_output +*.log diff --git a/demo.py b/demo1.py similarity index 55% rename from demo.py rename to demo1.py index af854f772e52206a6713c22b7cff1b753bd11051..505a6e0a68580cb888e077d164b7e9d370afe4cf 100644 --- a/demo.py +++ b/demo1.py @@ -1,10 +1,10 @@ import paddlepalm as palm if __name__ == '__main__': - controller = palm.Controller('config.yaml', task_dir='task_instance') + controller = palm.Controller('demo1_config.yaml', task_dir='demo1_tasks') controller.load_pretrain('pretrain_model/ernie/params') controller.train() - controller = palm.Controller(config='config.yaml', task_dir='task_instance', for_train=False) + controller = palm.Controller(config='demo1_config.yaml', task_dir='demo1_tasks', for_train=False) controller.pred('mrqa', inference_model_dir='output_model/firstrun/infer_model') diff --git a/demo1_config.yaml b/demo1_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3033cc0b418e281ae7d92579167166c084eb8e30 --- /dev/null +++ b/demo1_config.yaml @@ -0,0 +1,20 @@ +task_instance: "mrqa" +target_tag: 1 +mix_ratio: 1.0 + +save_path: "output_model/firstrun" + +backbone: "ernie" +backbone_config_path: "pretrain_model/ernie/ernie_config.json" + +vocab_path: "pretrain_model/ernie/vocab.txt" +do_lower_case: True +max_seq_len: 512 + +batch_size: 5 +num_epochs: 2 +optimizer: "adam" +learning_rate: 3e-5 +warmup_proportion: 0.1 +weight_decay: 0.1 + diff --git a/task_instance/mrqa.yaml b/demo1_tasks/mrqa.yaml similarity index 100% rename from task_instance/mrqa.yaml rename to demo1_tasks/mrqa.yaml diff --git a/demo2.py b/demo2.py new file mode 100644 index 0000000000000000000000000000000000000000..75a6e77fe01dcd45d0c677dab79eeeb98aca2a10 --- /dev/null +++ b/demo2.py @@ -0,0 +1,10 @@ +import paddlepalm as palm + +if __name__ == '__main__': + controller = palm.Controller('demo2_config.yaml', task_dir='demo2_tasks') + controller.load_pretrain('pretrain_model/ernie/params') + controller.train() + + controller = palm.Controller(config='demo2_config.yaml', task_dir='demo2_tasks', for_train=False) + controller.pred('mrqa', inference_model_dir='output_model/secondrun/infer_model') + diff --git a/config.yaml b/demo2_config.yaml similarity index 90% rename from config.yaml rename to demo2_config.yaml index 045f70c47aabcd76cc2115cc340933a3f2dc626b..23e9d86fcb527090470316a4841fb4a91969f8ab 100644 --- a/config.yaml +++ b/demo2_config.yaml @@ -2,7 +2,7 @@ task_instance: "mrqa, match4mrqa" target_tag: 1, 0 mix_ratio: 1.0, 0.5 -save_path: "output_model/firstrun" +save_path: "output_model/secondrun" backbone: "ernie" backbone_config_path: "pretrain_model/ernie/ernie_config.json" diff --git a/task_instance/match4mrqa.yaml b/demo2_tasks/match4mrqa.yaml similarity index 100% rename from task_instance/match4mrqa.yaml rename to demo2_tasks/match4mrqa.yaml diff --git a/demo2_tasks/mrqa.yaml b/demo2_tasks/mrqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62b9d4ee3a0171f2a88d5b9acad63d983842b6e7 --- /dev/null +++ b/demo2_tasks/mrqa.yaml @@ -0,0 +1,11 @@ +train_file: data/mrqa/mrqa-combined.train.raw.json +pred_file: data/mrqa/mrqa-combined.dev.raw.json +pred_output_path: 'mrqa_output' +reader: mrc4ernie +paradigm: mrc +doc_stride: 128 +max_query_len: 64 +max_answer_len: 30 +n_best_size: 20 +null_score_diff_threshold: 0.0 +verbose: False diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index 450d960b291d9e5777f306f7a5f9685ec3fac706..35500244cf27011bf03c2207d7c7f88280c92d7d 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -422,7 +422,7 @@ class Controller(object): prefixes.append(inst.name) mrs.append(inst.mix_ratio) - joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE) + joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, batch_size=main_conf['batch_size']) input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] @@ -488,10 +488,9 @@ class Controller(object): bb_fetches = {k: v.name for k,v in bb_output_vars.items()} task_fetches = {k: v.name for k,v in task_output_vars.items()} - old = len(bb_fetches)+len(task_fetches) # for debug - fetches = bb_fetches.copy() - fetches.update(task_fetches) - assert len(fetches) == old # for debug + # fetches = bb_fetches.copy() # 注意!框架在多卡时无法fetch变长维度的tensor,这里加入bb的out后会挂 + # fetches.update(task_fetches) + fetches = task_fetches fetches['__task_id'] = net_inputs['__task_id'].name # compute loss @@ -505,7 +504,8 @@ class Controller(object): num_examples = main_reader.num_examples for inst in instances: max_train_steps = int(main_conf['num_epochs']* inst.mix_ratio * num_examples) // main_conf['batch_size'] // dev_count - print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) + if inst.is_target: + print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count inst.expected_train_steps = max_train_steps @@ -622,12 +622,11 @@ class Controller(object): epoch = 0 time_begin = time.time() backbone_buffer = [] - task_buffer = [[]] * num_instances while not train_finish(): rt_outputs = self.exe.run(train_program, fetch_list=fetch_list) rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() - assert (not isinstance(rt_task_id, list)) or len(set(rt_task_id)) == 1 + assert (not isinstance(rt_task_id, list)) or len(set(rt_task_id)) == 1, rt_task_id rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id cur_task = instances[rt_task_id] @@ -635,8 +634,7 @@ class Controller(object): backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} - temp = instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) - task_buffer[rt_task_id].append(temp) + instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) global_step += 1 # if cur_task.is_target: diff --git a/paddlepalm/reader/mrc4ernie.py b/paddlepalm/reader/mrc4ernie.py index b79b62351a680e21b35babea82e8c1bf2ce4b509..0c550dfdb68cbaf442ceef0e1187db509e022139 100644 --- a/paddlepalm/reader/mrc4ernie.py +++ b/paddlepalm/reader/mrc4ernie.py @@ -30,6 +30,7 @@ class Reader(reader): max_seq_len=config['max_seq_len'], do_lower_case=config.get('do_lower_case', False), tokenizer='FullTokenizer', + for_cn=config.get('for_cn', False), doc_stride=config['doc_stride'], max_query_length=config['max_query_len'], random_seed=config.get('seed', None)) diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task_paradigm/match.py index e2c94075689b7b833b8476d8bd73b6b02cfdb211..58bbf35d56d16ab2b0a2ebcd9c653b9e1c666087 100644 --- a/paddlepalm/task_paradigm/match.py +++ b/paddlepalm/task_paradigm/match.py @@ -42,7 +42,8 @@ class TaskParadigm(task_paradigm): return {"logits": [[-1, 1], 'float32']} def build(self, inputs): - labels = inputs["reader"]["label_ids"] + if self._is_training: + labels = inputs["reader"]["label_ids"] cls_feats = inputs["backbone"]["sentence_pair_embedding"] cls_feats = fluid.layers.dropout( @@ -58,11 +59,11 @@ class TaskParadigm(task_paradigm): bias_attr=fluid.ParamAttr( name="cls_out_b", initializer=fluid.initializer.Constant(0.))) - ce_loss, probs = fluid.layers.softmax_with_cross_entropy( - logits=logits, label=labels, return_softmax=True) - loss = fluid.layers.mean(x=ce_loss) if self._is_training: + ce_loss, probs = fluid.layers.softmax_with_cross_entropy( + logits=logits, label=labels, return_softmax=True) + loss = fluid.layers.mean(x=ce_loss) return {'loss': loss} else: return {'logits': logits} diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task_paradigm/mrc.py index 4889f24febdd4b95458a1c784127e49935ec0898..1d3642ad87328dd3c5b874cbad4eb2662e732538 100644 --- a/paddlepalm/task_paradigm/mrc.py +++ b/paddlepalm/task_paradigm/mrc.py @@ -65,9 +65,7 @@ class TaskParadigm(task_paradigm): @property def outputs_attr(self): if self._is_training: - return {'start_logits': [[-1, -1, 1], 'float32'], - 'end_logits': [[-1, -1, 1], 'float32'], - 'loss': [[1], 'float32']} + return {'loss': [[1], 'float32']} else: return {'start_logits': [[-1, -1, 1], 'float32'], 'end_logits': [[-1, -1, 1], 'float32'], @@ -106,16 +104,14 @@ class TaskParadigm(task_paradigm): start_loss = _compute_single_loss(start_logits, start_positions) end_loss = _compute_single_loss(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2.0 - return {'start_logits': start_logits, - 'end_logits': end_logits, - 'loss': total_loss} + return {'loss': total_loss} else: return {'start_logits': start_logits, 'end_logits': end_logits, 'unique_ids': unique_id} - def postprocess(self, rt_outputs): + def postprocess(self, rt_outputs): """this func will be called after each step(batch) of training/evaluating/predicting process.""" if not self._is_training: unique_ids = np.squeeze(rt_outputs['unique_ids'], -1) diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 4cd4d32e16717ef8d87becf77dcaf1bfad94ed66..0124e63510e1b4a0ac6486980add758ffc8cacfa 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -48,6 +48,20 @@ def _zero_batch(attrs): return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs] +def _zero_batch_x(attrs, batch_size): + pos_attrs = [] + for shape, dtype in attrs: + # pos_shape = [size if size and size > 0 else 5 for size in shape] + pos_shape = [size for size in shape] + if pos_shape[0] == -1: + pos_shape[0] = batch_size + if pos_shape[1] == -1: + pos_shape[1] = 512 # max seq len + pos_attrs.append([pos_shape, dtype]) + + return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs] + + def create_net_inputs(input_attrs, async=False, iterator_fn=None, dev_count=1, n_prefetch=1): inputs = [] ret = {} @@ -92,10 +106,11 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p return iterator -def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0): +def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0, batch_size=None): """ joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查 """ + task_ids = range(len(iterators)) weights = [mr / float(sum(mrs)) for mr in mrs] if not keep_one_task: @@ -129,7 +144,6 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype v = verbose while True: id = np.random.choice(task_ids, p=weights) - # results = _zero_batch(joint_shape_and_dtypes) results = fake_batch if v > 0: print('----- debug joint iterator -----') @@ -138,6 +152,8 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype results[0] = task_id_tensor for i in range(dev_count): + # results = _zero_batch(joint_shape_and_dtypes, batch_size=batch_size) + # results[0] = task_id_tensor if id in outbuf: outputs = outbuf[id] del outbuf[id] diff --git a/run_demo.sh b/run_demo1.sh similarity index 62% rename from run_demo.sh rename to run_demo1.sh index e9cddc1f21f22c33f145cc4a2b21174a6d94d6c0..d6d30cae1ee04df82d20d1195ea039f865443f1a 100755 --- a/run_demo.sh +++ b/run_demo1.sh @@ -1,6 +1,6 @@ -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0,1,2,3 export FLAGS_fraction_of_gpu_memory_to_use=0.1 export FLAGS_eager_delete_tensor_gb=0 -python demo.py +python demo1.py diff --git a/run_demo2.sh b/run_demo2.sh new file mode 100755 index 0000000000000000000000000000000000000000..ca69529f8b0e9dad12e862eab819bc06c83f293d --- /dev/null +++ b/run_demo2.sh @@ -0,0 +1,6 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_fraction_of_gpu_memory_to_use=0.1 +export FLAGS_eager_delete_tensor_gb=0 + +python demo2.py + diff --git a/download_pretrain_backbone.sh b/script/download_pretrain_backbone.sh similarity index 100% rename from download_pretrain_backbone.sh rename to script/download_pretrain_backbone.sh diff --git a/task_instance/mlm4mrqa.yaml b/task_instance/mlm4mrqa.yaml deleted file mode 100644 index 1b44610d7ef4eb059c13a42b4cb11b2c1c13c886..0000000000000000000000000000000000000000 --- a/task_instance/mlm4mrqa.yaml +++ /dev/null @@ -1,5 +0,0 @@ -train_file: "data/mlm4mrqa" -mix_ratio: 0.4 -batch_size: 4 -in_tokens: False -generate_neg_sample: False