From f567603f501591353d4f6b586a7189f774eae766 Mon Sep 17 00:00:00 2001
From: xixiaoyao
Date: Wed, 30 Oct 2019 12:24:09 +0800
Subject: [PATCH] fix bugs

Fix a gather out-of-range crash in the MLM task paradigm under
multi-task training: mask_pos is now clamped via elementwise_min
against a new batchsize_x_seqlen input, so a batch sampled from a task
with a shorter seqlen can no longer push gather indices past the
flattened encoder output. merge_input_attrs gains optional
insert_batchsize / insert_seqlen / insert_batchsize_x_seqlen flags to
carry such synthetic inputs. Also: tune the demo configs, rename
config_demo3.py to config_demo3.yaml, add the mlm4mrqa task config,
pin the demo scripts to a single GPU, disable the trailing
step_N_final save, and drop a stray nohup.out from the repo.
---
 config_demo1.yaml                    |   4 +-
 config_demo2.yaml                    |  10 +--
 config_demo3.py => config_demo3.yaml |  10 +--
 demo3.py                             |   4 +-
 demo3_tasks/mlm4mrqa.yaml            |   3 +
 nohup.out                            | 110 ---------------------------
 paddlepalm/mtl_controller.py         |   8 +-
 paddlepalm/reader/cls.py             |   4 +-
 paddlepalm/reader/mlm.py             |   3 +-
 paddlepalm/task_paradigm/mlm.py      |  11 ++-
 paddlepalm/utils/reader_helper.py    |  56 +++++++++++---
 run_demo1.sh                         |   2 +-
 run_demo2.sh                         |   9 +--
 run_demo3.sh                         |   2 +-
 14 files changed, 87 insertions(+), 149 deletions(-)
 rename config_demo3.py => config_demo3.yaml (53%)
 create mode 100644 demo3_tasks/mlm4mrqa.yaml
 delete mode 100644 nohup.out

diff --git a/config_demo1.yaml b/config_demo1.yaml
index 384050a..aca88e5 100644
--- a/config_demo1.yaml
+++ b/config_demo1.yaml
@@ -5,8 +5,8 @@ save_path: "output_model/firstrun"
 backbone: "bert"
 backbone_config_path: "pretrain_model/bert/bert_config.json"
 
-batch_size: 5
-num_epochs: 3
+batch_size: 4
+num_epochs: 2
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1

diff --git a/config_demo2.yaml b/config_demo2.yaml
index a033d5b..7b55af5 100644
--- a/config_demo2.yaml
+++ b/config_demo2.yaml
@@ -1,6 +1,6 @@
-task_instance: "mrqa, mlm4mrqa, match4mrqa"
-target_tag: 1, 0, 0
-mix_ratio: 0.5, 1.0, 0.5
+task_instance: "mrqa, match4mrqa"
+target_tag: 1, 0
+mix_ratio: 0.5, 0.5
 
 save_path: "output_model/secondrun"
 
@@ -11,8 +11,8 @@ vocab_path: "pretrain_model/ernie/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 
-batch_size: 5
-num_epochs: 5
+batch_size: 4
+num_epochs: 2
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1

diff --git a/config_demo3.py b/config_demo3.yaml
similarity index 53%
rename from config_demo3.py
rename to config_demo3.yaml
index 08f7025..7731a8d 100644
--- a/config_demo3.py
+++ b/config_demo3.yaml
@@ -1,16 +1,16 @@
-task_instance: "mrqa"
+task_instance: "mlm4mrqa"
 
 save_path: "output_model/firstrun"
 
-backbone: "bert"
-backbone_config_path: "pretrain_model/bert/bert_config.json"
+backbone: "ernie"
+backbone_config_path: "pretrain_model/ernie/ernie_config.json"
 
-vocab_path: "pretrain_model/bert/vocab.txt"
+vocab_path: "pretrain_model/ernie/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 
 batch_size: 5
-num_epochs: 3
+num_epochs: 100
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1

diff --git a/demo3.py b/demo3.py
index 55413ec..59d3242 100644
--- a/demo3.py
+++ b/demo3.py
@@ -5,6 +5,6 @@ if __name__ == '__main__':
     controller.load_pretrain('pretrain_model/ernie/params')
     controller.train()
 
-    controller = palm.Controller(config='config_demo3.yaml', task_dir='demo3_tasks', for_train=False)
-    controller.pred('cls4mrqa', inference_model_dir='output_model/thirdrun/infer_model')
+    # controller = palm.Controller(config='config_demo3.yaml', task_dir='demo3_tasks', for_train=False)
+    # controller.pred('cls4mrqa', inference_model_dir='output_model/thirdrun/infer_model')

diff --git a/demo3_tasks/mlm4mrqa.yaml b/demo3_tasks/mlm4mrqa.yaml
new file mode 100644
index 0000000..c24296e
--- /dev/null
+++ b/demo3_tasks/mlm4mrqa.yaml
@@ -0,0 +1,3 @@
+train_file: "data/mlm4mrqa/train.tsv"
+reader: mlm
+paradigm: mlm
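For orientation, demo3 now trains the mlm4mrqa task defined just above on an ERNIE backbone. A minimal sketch of the full driver script, assuming the usual import and Controller construction that sit outside the demo3.py hunk:

    import paddlepalm as palm

    if __name__ == '__main__':
        # config_demo3.yaml points task_instance at mlm4mrqa; its reader and
        # paradigm come from demo3_tasks/mlm4mrqa.yaml.
        controller = palm.Controller(config='config_demo3.yaml', task_dir='demo3_tasks')
        controller.load_pretrain('pretrain_model/ernie/params')
        controller.train()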
diff --git a/nohup.out b/nohup.out
deleted file mode 100644
index a90603a..0000000
--- a/nohup.out
+++ /dev/null
@@ -1,110 +0,0 @@
-W1028 21:51:59.319365  9630 device_context.cc:235] Please NOTE: device: 0, CUDA Capability: 61, Driver API Version: 10.1, Runtime API Version: 9.0
-W1028 21:51:59.323333  9630 device_context.cc:243] device: 0, cuDNN Version: 7.3.
-I1028 21:52:26.817137  9630 parallel_executor.cc:421] The number of CUDAPlace, which is used in ParallelExecutor, is 8. And the Program will be copied 8 copies
-W1028 21:52:41.982228  9630 fuse_all_reduce_op_pass.cc:72] Find all_reduce operators: 401. To make the speed faster, some all_reduce ops are fused during training, after fusion, the number of all_reduce ops is 255.
-I1028 21:52:42.243458  9630 build_strategy.cc:363] SeqOnlyAllReduceOps:0, num_trainers:1
-I1028 21:53:14.242537  9630 parallel_executor.cc:285] Inplace strategy is enabled, when build_strategy.enable_inplace = True
-I1028 21:53:16.313246  9630 parallel_executor.cc:368] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
-/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/executor.py:774: UserWarning: The following exception is not an EOF exception.
-  "The following exception is not an EOF exception.")
-Traceback (most recent call last):
-  File "demo2.py", line 6, in <module>
-    controller.train()
-  File "/home/ssd7/yiming/release/PALM/paddlepalm/mtl_controller.py", line 669, in train
-    fluid.io.save_persistables(self.exe, save_path, saver_program)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/io.py", line 571, in save_persistables
-    filename=filename)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/io.py", line 216, in save_vars
-    filename=filename)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/io.py", line 256, in save_vars
-    executor.run(save_program)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/executor.py", line 775, in run
-    six.reraise(*sys.exc_info())
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/executor.py", line 770, in run
-    use_program_cache=use_program_cache)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/executor.py", line 817, in _run_impl
-    use_program_cache=use_program_cache)
-  File "/home/zhangyiming/env-bert/lib/python2.7/site-packages/paddle/fluid/executor.py", line 894, in _run_program
-    fetch_var_name)
-paddle.fluid.core_avx.EnforceNotMet: 
-
---------------------------------------------
-C++ Call Stacks (More useful to developers):
---------------------------------------------
-0   std::string paddle::platform::GetTraceBackString(char const*&&, char const*, int)
-1   paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int)
-2   paddle::operators::SaveOpKernel::SaveLodTensor(paddle::framework::ExecutionContext const&, boost::variant const&, paddle::framework::Variable const*) const
-3   paddle::operators::SaveOpKernel::Compute(paddle::framework::ExecutionContext const&) const
-4   std::_Function_handler, paddle::operators::SaveOpKernel, paddle::operators::SaveOpKernel> demo2.log
-W1029 10:43:27.495725 32687 device_context.cc:235] Please NOTE: device: 0, CUDA Capability: 61, Driver API Version: 10.1, Runtime API Version: 9.0
-W1029 10:43:27.500324 32687 device_context.cc:243] device: 0, cuDNN Version: 7.3.
-I1029 10:43:41.409127 32687 parallel_executor.cc:421] The number of CUDAPlace, which is used in ParallelExecutor, is 8. And the Program will be copied 8 copies
-W1029 10:44:03.299010 32687 fuse_all_reduce_op_pass.cc:72] Find all_reduce operators: 401. To make the speed faster, some all_reduce ops are fused during training, after fusion, the number of all_reduce ops is 255.
-I1029 10:44:03.584228 32687 build_strategy.cc:363] SeqOnlyAllReduceOps:0, num_trainers:1
-I1029 10:44:39.690382 32687 parallel_executor.cc:285] Inplace strategy is enabled, when build_strategy.enable_inplace = True
-I1029 10:44:42.244774 32687 parallel_executor.cc:368] Garbage collection strategy is enabled, when FLAGS_eager_delete_tensor_gb = 0
-W1029 10:48:20.253201 32687 init.cc:212] *** Aborted at 1572317300 (unix time) try "date -d @1572317300" if you are using GNU date ***
-W1029 10:48:20.255347 32687 init.cc:212] PC: @                0x0 (unknown)
-W1029 10:48:20.255458 32687 init.cc:212] *** SIGTERM (@0x1f80000785a) received by PID 32687 (TID 0x7f0f71d25700) from PID 30810; stack trace: ***
-W1029 10:48:20.257107 32687 init.cc:212]     @     0x7f0f714ef160 (unknown)
-W1029 10:48:20.258708 32687 init.cc:212]     @     0x7f0f714eb3cc __pthread_cond_wait
-W1029 10:48:20.259734 32687 init.cc:212]     @     0x7f0f249d33cc std::condition_variable::wait()
-W1029 10:48:20.263964 32687 init.cc:212]     @     0x7f0f008e990d paddle::framework::details::FastThreadedSSAGraphExecutor::Run()
-W1029 10:48:20.265229 32687 init.cc:212]     @     0x7f0f0084a6a7 _ZNSt17_Function_handlerIFvvEZN6paddle9framework7details29ScopeBufferedSSAGraphExecutor3RunERKSt6vectorISsSaISsEEEUlvE_E9_M_invokeERKSt9_Any_data
-W1029 10:48:20.268503 32687 init.cc:212]     @     0x7f0f0084f4bf paddle::framework::details::ScopeBufferedMonitor::Apply()
-W1029 10:48:20.270135 32687 init.cc:212]     @     0x7f0f0084ae86 paddle::framework::details::ScopeBufferedSSAGraphExecutor::Run()
-W1029 10:48:20.272866 32687 init.cc:212]     @     0x7f0efe5ed038 paddle::framework::ParallelExecutor::Run()
-W1029 10:48:20.273551 32687 init.cc:212]     @     0x7f0efe3d0e78 _ZZN8pybind1112cpp_function10initializeIZN6paddle6pybindL22pybind11_init_core_avxERNS_6moduleEEUlRNS2_9framework16ParallelExecutorERKSt6vectorISsSaISsEEE188_S9_INS6_9LoDTensorESaISF_EEIS8_SD_EINS_4nameENS_9is_methodENS_7siblingEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE1_4_FUNESY_
-W1029 10:48:20.274988 32687 init.cc:212]     @     0x7f0efe41af56 pybind11::cpp_function::dispatcher()
-W1029 10:48:20.276706 32687 init.cc:212]     @     0x7f0f71808cc8 PyEval_EvalFrameEx
-W1029 10:48:20.278395 32687 init.cc:212]     @     0x7f0f7180b35d PyEval_EvalCodeEx
-W1029 10:48:20.280076 32687 init.cc:212]     @     0x7f0f71808d50 PyEval_EvalFrameEx
-W1029 10:48:20.281765 32687 init.cc:212]     @     0x7f0f7180b35d PyEval_EvalCodeEx
-W1029 10:48:20.283442 32687 init.cc:212]     @     0x7f0f71808d50 PyEval_EvalFrameEx
-W1029 10:48:20.285133 32687 init.cc:212]     @     0x7f0f7180b35d PyEval_EvalCodeEx
-W1029 10:48:20.286808 32687 init.cc:212]     @     0x7f0f71808d50 PyEval_EvalFrameEx
-W1029 10:48:20.288502 32687 init.cc:212]     @     0x7f0f7180b35d PyEval_EvalCodeEx
-W1029 10:48:20.290176 32687 init.cc:212]     @     0x7f0f71808d50 PyEval_EvalFrameEx
-W1029 10:48:20.291870 32687 init.cc:212]     @     0x7f0f7180b35d PyEval_EvalCodeEx
-W1029 10:48:20.293542 32687 init.cc:212]     @     0x7f0f7180b492 PyEval_EvalCode
-W1029 10:48:20.295228 32687 init.cc:212]     @     0x7f0f718351a2 PyRun_FileExFlags
-W1029 10:48:20.296922 32687 init.cc:212]     @     0x7f0f71836539 PyRun_SimpleFileExFlags
-W1029 10:48:20.298590 32687 init.cc:212]     @     0x7f0f7184c1bd Py_Main
-W1029 10:48:20.300307 32687 init.cc:212]     @     0x7f0f70a49bd5 __libc_start_main
-W1029 10:48:20.300364 32687 init.cc:212]     @           0x4007a1 (unknown)
-W1029 10:48:20.302006 32687 init.cc:212]     @                0x0 (unknown)

diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index b4716d5..bd8456a 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -557,7 +557,7 @@ class Controller(object):
         inst.task_layer['pred'] = pred_parad
         pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs(
             pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'],
-            insert_taskid=False)
+            insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
 
         pred_prog = inst.load(infer_model_path)
         # pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
@@ -664,9 +664,9 @@ class Controller(object):
                                 "step_" + str(global_step))
                         fluid.io.save_persistables(self.exe, save_path, saver_program)
 
-        save_path = os.path.join(main_conf['save_path'],
-                                 "step_" + str(global_step) + "_final")
-        fluid.io.save_persistables(self.exe, save_path, saver_program)
+        # save_path = os.path.join(main_conf['save_path'],
+        #                          "step_" + str(global_step) + "_final")
+        # fluid.io.save_persistables(self.exe, save_path, saver_program)
 
     def pred(self, task_instance, inference_model_dir=None):
         if self._for_train:
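With the trailing step_N_final save commented out above, the newest periodic step_N directory under save_path is the de facto final checkpoint. A hypothetical stdlib-only helper, not part of this patch, that downstream scripts could use to locate it:

    import os
    import re

    def latest_checkpoint(save_dir):
        # Pick the step_<N> directory with the largest step number among the
        # periodic saves written by fluid.io.save_persistables during train().
        steps = [int(m.group(1)) for name in os.listdir(save_dir)
                 for m in [re.match(r'step_(\d+)$', name)] if m]
        return os.path.join(save_dir, 'step_%d' % max(steps)) if steps else None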
diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py
index 4c37e67..286226d 100644
--- a/paddlepalm/reader/cls.py
+++ b/paddlepalm/reader/cls.py
@@ -42,7 +42,7 @@ class Reader(reader):
             self._input_file = config['train_file']
             self._num_epochs = None # prevent the iterator from terminating
             self._shuffle = config.get('shuffle', False)
-            self._shuffle_buffer = config.get('shuffle_buffer', 5000)
+            # self._shuffle_buffer = config.get('shuffle_buffer', 5000)
         elif phase == 'eval':
             self._input_file = config['dev_file']
             self._num_epochs = 1
@@ -56,7 +56,7 @@ class Reader(reader):
 
         self._phase = phase
         # self._batch_size = 
-        self._print_first_n = config.get('print_first_n', 1)
+        self._print_first_n = config.get('print_first_n', 0)
 
 
     @property

diff --git a/paddlepalm/reader/mlm.py b/paddlepalm/reader/mlm.py
index bc9f190..3739bea 100644
--- a/paddlepalm/reader/mlm.py
+++ b/paddlepalm/reader/mlm.py
@@ -66,7 +66,7 @@ class Reader(reader):
             "input_mask": [[-1, -1, 1], 'float32'],
             "task_ids": [[-1, -1, 1], 'int64'],
             "mask_label": [[-1, 1], 'int64'],
-            "mask_pos": [[-1, 1], 'int64']
+            "mask_pos": [[-1, 1], 'int64'],
             }
 
 
@@ -79,6 +79,7 @@ class Reader(reader):
             names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
                      'task_ids', 'mask_label', 'mask_pos']
             outputs = {n: i for n,i in zip(names, x)}
+            # outputs['batchsize_x_seqlen'] = [self._batch_size * len(outputs['token_ids'][0]) - 1]
             return outputs
 
         for batch in self._data_generator():

diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task_paradigm/mlm.py
index 53c2866..08a4e42 100644
--- a/paddlepalm/task_paradigm/mlm.py
+++ b/paddlepalm/task_paradigm/mlm.py
@@ -34,9 +34,11 @@ class TaskParadigm(task_paradigm):
     def inputs_attrs(self):
         reader = {
             "mask_label": [[-1, 1], 'int64'],
+            "batchsize_x_seqlen": [[1], 'int64'],
             "mask_pos": [[-1, 1], 'int64']}
         if not self._is_training:
             del reader['mask_label']
+            del reader['batchsize_x_seqlen']
         bb = {
             "encoder_outputs": [[-1, -1, self._hidden_size], 'float32'],
             "embedding_table": [[-1, self._vocab_size, self._emb_size], 'float32']}
@@ -52,6 +54,8 @@ class TaskParadigm(task_paradigm):
     def build(self, inputs):
         if self._is_training:
             mask_label = inputs["reader"]["mask_label"]
+            # only needed for multi-task learning: keeps gather in range when a batch sampled from another task makes seqlen too small
+            batchsize_x_seqlen = inputs["reader"]["batchsize_x_seqlen"]
         mask_pos = inputs["reader"]["mask_pos"]
         word_emb = inputs["backbone"]["embedding_table"]
         enc_out = inputs["backbone"]["encoder_outputs"]
@@ -61,7 +65,12 @@ class TaskParadigm(task_paradigm):
         _param_initializer = fluid.initializer.TruncatedNormal(
             scale=self._initializer_range)
 
-        mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
+        if self._is_training:
+            # only needed for multi-task training: keeps gather in range when a batch sampled from another task makes seqlen too small
+            #mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
+            mask_pos = fluid.layers.elementwise_min(mask_pos, batchsize_x_seqlen)
+
+        #print(fluid.default_main_program().blocks[0].vars)
 
         reshaped_emb_out = fluid.layers.reshape(
             x=enc_out, shape=[-1, emb_size])
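The elementwise_min clamp above is the core fix. The MLM head gathers masked positions out of the encoder output flattened to [batch_size * seqlen, hidden]; under joint training an MLM batch slot can be filled while another task with a shorter seqlen is sampled, leaving mask_pos values that overrun the flattened tensor. A numpy sketch of the failure mode and the clamp, with invented shapes for illustration:

    import numpy as np

    batch_size, seqlen, hidden = 4, 6, 8
    enc_out = np.random.rand(batch_size, seqlen, hidden)
    flat = enc_out.reshape(-1, hidden)      # gather space: batch_size * seqlen rows

    mask_pos = np.array([5, 17, 29])        # 29 was valid for a longer-seqlen batch
    bound = batch_size * seqlen - 1         # 23: the batchsize_x_seqlen value, i.e. the
                                            # last valid flat index (cf. the commented-out
                                            # reader line above)
    safe_pos = np.minimum(mask_pos, bound)  # what elementwise_min does in-graph
    gathered = flat[safe_pos]               # no out-of-range read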
diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py
index c9e5f88..6f66c75 100644
--- a/paddlepalm/utils/reader_helper.py
+++ b/paddlepalm/utils/reader_helper.py
@@ -143,6 +143,7 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
 
     def iterator():
         v = verbose
+        has_show_warn = False
         while True:
             id = np.random.choice(task_ids, p=weights)
             results = fake_batch
@@ -150,16 +151,37 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
                 print('----- debug joint iterator -----')
                 print('sampled task id: '+str(id))
             task_id_tensor = np.array([[id]]).astype("int64")
-            results[0] = task_id_tensor
+            # results[0] = task_id_tensor
 
             for i in range(dev_count):
-                results[0] = task_id_tensor
+
+                # these two should be equivalent
+                # results[0] = task_id_tensor
+                results[outname_to_pos['__task_id']] = task_id_tensor
+                assert outname_to_pos['__task_id'] == 0
+
                 if id in outbuf:
                     outputs = outbuf[id]
                     del outbuf[id]
                 else:
                     outputs = next(iterators[id])   # dict type
 
+                # if 'token_ids' in outputs:
+                #     val1 = len(outputs['token_ids'])
+                #     val = _check_and_adapt_shape_dtype([val1], [[1], 'int64'])
+                #     results[outname_to_pos['batch_size']] = val
+
+                #     val2 = len(outputs['token_ids'][0])
+                #     val = _check_and_adapt_shape_dtype([val2], [[1], 'int64'])
+                #     results[outname_to_pos['seqlen']] = val
+
+                #     val = _check_and_adapt_shape_dtype([val1*val2], [[1], 'int64'])
+                #     results[outname_to_pos['batchsize_x_seqlen']] = val
+                # else:
+                #     if not has_show_warn:
+                #         print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
+                #         has_show_warn = True
+
                 prefix = iterator_prefixes[id]
                 for outname, val in outputs.items():
                     if v > 0:
@@ -192,7 +214,7 @@
     return iterator
 
 
-def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True):
+def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False):
     """
     Args:
         task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks
@@ -200,14 +222,28 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True):
     if isinstance(task_attrs, dict):
         task_attrs = [task_attrs]
 
+    ret = []
+    names = []
+    start = 0
     if insert_taskid:
-        ret = [([1,1], 'int64')]
-        names = ['__task_id']
-        start = 1
-    else:
-        ret = []
-        names = []
-        start = 0
+        ret.append(([1,1], 'int64'))
+        names.append('__task_id')
+        start += 1
+
+    if insert_batchsize:
+        ret.append(([1], 'int64'))
+        names.append('batch_size')
+        start += 1
+
+    if insert_seqlen:
+        ret.append(([1], 'int64'))
+        names.append('seqlen')
+        start += 1
+
+    if insert_batchsize_x_seqlen:
+        ret.append(([1], 'int64'))
+        names.append('batchsize_x_seqlen')
+        start += 1
 
     names += sorted(backbone_attr.keys())
     ret.extend([backbone_attr[k] for k in names[start:]])
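With the new flags, merge_input_attrs lays the joint feed out as optional synthetic scalars first (__task_id, batch_size, seqlen, batchsize_x_seqlen, in that order, consistent with the assert on outname_to_pos['__task_id'] above), then backbone inputs in sorted-name order, then task inputs. A rough pure-Python sketch of the resulting name layout; the task-name prefixing is assumed, since that part of the function sits outside the hunk:

    def merged_names(backbone_attr, task_attrs, prefixes,
                     insert_taskid=True, insert_batchsize=False,
                     insert_seqlen=False, insert_batchsize_x_seqlen=False):
        names = []
        for flag, name in [(insert_taskid, '__task_id'),
                           (insert_batchsize, 'batch_size'),
                           (insert_seqlen, 'seqlen'),
                           (insert_batchsize_x_seqlen, 'batchsize_x_seqlen')]:
            if flag:
                names.append(name)
        names += sorted(backbone_attr.keys())
        for prefix, attrs in zip(prefixes, task_attrs):
            names += [prefix + k for k in sorted(attrs.keys())]
        return names

    # merged_names({'input_mask': 0, 'token_ids': 0}, [{'mask_pos': 0}], ['mlm4mrqa.'])
    # -> ['__task_id', 'input_mask', 'token_ids', 'mlm4mrqa.mask_pos']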
diff --git a/run_demo1.sh b/run_demo1.sh
index a73cb1b..3f3d8ec 100755
--- a/run_demo1.sh
+++ b/run_demo1.sh
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export CUDA_VISIBLE_DEVICES=0
 
 python demo1.py
 

diff --git a/run_demo2.sh b/run_demo2.sh
index 56376ef..02c40ba 100755
--- a/run_demo2.sh
+++ b/run_demo2.sh
@@ -1,6 +1,5 @@
-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-while true
-do
-    python demo2.py
-done
+export CUDA_VISIBLE_DEVICES=0
+
+python -u demo2.py
+# GLOG_vmodule=lookup_table_op=4 python -u demo2.py > debug2.log 2>&1
 

diff --git a/run_demo3.sh b/run_demo3.sh
index d32afd2..1e2c7c3 100755
--- a/run_demo3.sh
+++ b/run_demo3.sh
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=0,1
+export CUDA_VISIBLE_DEVICES=0
 
 python demo3.py
 
-- 
GitLab