From cde7404eeceef251a6b9752ad19a639c2f4fa87e Mon Sep 17 00:00:00 2001
From: xixiaoyao
Date: Thu, 31 Oct 2019 17:50:27 +0800
Subject: [PATCH] fix bugs and add demo3

---
 config_demo3.yaml                 | 10 ++++---
 data/cls4mrqa/train.tsv           |  2 +-
 demo3_tasks/cls1.yaml             |  5 ++++
 demo3_tasks/cls2.yaml             |  5 ++++
 demo3_tasks/cls3.yaml             |  5 ++++
 demo3_tasks/cls4.yaml             |  5 ++++
 demo3_tasks/cls5.yaml             |  5 ++++
 demo3_tasks/cls6.yaml             |  5 ++++
 demo3_tasks/mlm4mrqa.yaml         |  3 --
 demo3_tasks/mrqa.yaml             |  3 --
 paddlepalm/mtl_controller.py      | 47 ++++++++++++++++---------------
 paddlepalm/reader/cls.py          |  3 +-
 paddlepalm/task_instance.py       | 10 +++----
 paddlepalm/task_paradigm/cls.py   | 13 ++++-----
 paddlepalm/task_paradigm/match.py |  6 ++--
 paddlepalm/task_paradigm/mlm.py   | 10 +++----
 paddlepalm/task_paradigm/mrc.py   |  6 ++--
 run_demo2.sh                      |  6 +++-
 18 files changed, 89 insertions(+), 60 deletions(-)
 create mode 100644 demo3_tasks/cls1.yaml
 create mode 100644 demo3_tasks/cls2.yaml
 create mode 100644 demo3_tasks/cls3.yaml
 create mode 100644 demo3_tasks/cls4.yaml
 create mode 100644 demo3_tasks/cls5.yaml
 create mode 100644 demo3_tasks/cls6.yaml
 delete mode 100644 demo3_tasks/mlm4mrqa.yaml
 delete mode 100644 demo3_tasks/mrqa.yaml

diff --git a/config_demo3.yaml b/config_demo3.yaml
index 7731a8d..6ce4a61 100644
--- a/config_demo3.yaml
+++ b/config_demo3.yaml
@@ -1,6 +1,8 @@
-task_instance: "mlm4mrqa"
+task_instance: "cls1, cls2, cls3, cls4, cls5, cls6"
 
-save_path: "output_model/firstrun"
+task_reuse_tag: 0,0,1,1,0,2
+
+save_path: "output_model/thirdrun"
 
 backbone: "ernie"
 backbone_config_path: "pretrain_model/ernie/ernie_config.json"
@@ -9,8 +11,8 @@ vocab_path: "pretrain_model/ernie/vocab.txt"
 do_lower_case: True
 max_seq_len: 512
 
-batch_size: 5
-num_epochs: 100
+batch_size: 4
+num_epochs: 0.5
 optimizer: "adam"
 learning_rate: 3e-5
 warmup_proportion: 0.1
diff --git a/data/cls4mrqa/train.tsv b/data/cls4mrqa/train.tsv
index eace0f4..8b565c4 100644
--- a/data/cls4mrqa/train.tsv
+++ b/data/cls4mrqa/train.tsv
@@ -29,7 +29,7 @@ label text_a
 0 Trophy hunting can include areas which would likely be unsuitable for what other types of ecotourism?study states that less than 3% of a trophy hunters' expenditures reach the local level, meaning that the economic incentive and benefit is "minimal, particularly when we consider the vast areas of
 1 In simple language, what are the interconnections in an embedding matrix?Since it was quite easy to stack interconnections (wires) inside the embedding matrix, the approach allowed designers to forget completely about the routing of wires (usually a time-consuming operation of PCB design): Anywhere the designer needs a connection, the machine will draw a wire in straight line from one location/pin
 2 rho has been to the most all star games in baseballn4
 • Stan Musial 24
 •
-0In 1169, Ireland was invaded by which people?High King to ensure the terms of the Treaty of Windsor led Henry II, as King of England, to rule as effective monarch under the title of Lord of Ireland. This title was granted to his younger son but when Henry's heir unexpectedly died the title of King of England and Lord of Ireland became entwined in one
+0 In 1169, Ireland was invaded by which people?High King to ensure the terms of the Treaty of Windsor led Henry II, as King of England, to rule as effective monarch under the title of Lord of Ireland. This title was granted to his younger son but when Henry's heir unexpectedly died the title of King of England and Lord of Ireland became entwined in one
 1 What year did a biracial Populist fusion gain the Governors office?to the legislature and governor's office, but the Populists attracted voters displeased with them. In 1896 a biracial, Populist-Republican Fusionist coalition gained the governor's office. The Democrats regained control of the legislature
 1 nearest metro station to majnu ka tilla delhiRing Road of Delhi . It is at a walkable distance from ISBT Kashmere Gate . It is approachable through the Kashmeri Gate station of the Delhi Metro , lies on both the Red ( Dilshad Garden - Rithala ) and Yellow Lines ( Samaypur
 3 where is california located in the region of the united states
 
     California is a U.S. state in the Pacific Region of the United States . With 39.5 million residents , California is the most populous state in the United States and the third largest by area . The
diff --git a/demo3_tasks/cls1.yaml b/demo3_tasks/cls1.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls1.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/cls2.yaml b/demo3_tasks/cls2.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls2.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/cls3.yaml b/demo3_tasks/cls3.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls3.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/cls4.yaml b/demo3_tasks/cls4.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls4.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/cls5.yaml b/demo3_tasks/cls5.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls5.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/cls6.yaml b/demo3_tasks/cls6.yaml
new file mode 100644
index 0000000..79f548d
--- /dev/null
+++ b/demo3_tasks/cls6.yaml
@@ -0,0 +1,5 @@
+train_file: "data/cls4mrqa/train.tsv"
+reader: cls
+paradigm: cls
+
+n_classes: 4
diff --git a/demo3_tasks/mlm4mrqa.yaml b/demo3_tasks/mlm4mrqa.yaml
deleted file mode 100644
index c24296e..0000000
--- a/demo3_tasks/mlm4mrqa.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-train_file: "data/mlm4mrqa/train.tsv"
-reader: mlm
-paradigm: mlm
diff --git a/demo3_tasks/mrqa.yaml b/demo3_tasks/mrqa.yaml
deleted file mode 100644
index 4faa33e..0000000
--- a/demo3_tasks/mrqa.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-train_file: data/cls4mrqa/train.tsv
-reader: cls
-paradigm: cls
diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index 12fbd0b..40aef0c 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -208,7 +208,7 @@ class Controller(object):
         self.exe = exe
         self.dev_count = dev_count
 
-        print_dict(mtl_conf, title='main configuration')
+        print_dict(mtl_conf, title='global configuration')
 
         # parse task instances and target tags
         instnames = _parse_list(mtl_conf['task_instance'])
@@ -230,6 +230,18 @@ class Controller(object):
             instname_to_conf[instname] = conf
             instname_to_id[instname] = id
 
+        # prepare backbone
+        if 'backbone_config_path' in mtl_conf:
+            bb_conf = _parse_json(mtl_conf['backbone_config_path'])
+            bb_conf = _merge_conf(mtl_conf, bb_conf)
+        else:
+            bb_conf = mtl_conf
+        print_dict(bb_conf, title='backbone configuration'.format(instname))
+
+        bb_name = mtl_conf['backbone']
+        bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
+        Backbone = getattr(bb_mod, 'Model')
+
         # create task instances
         instances = []
         for name in instnames:
@@ -294,8 +306,8 @@ class Controller(object):
             for j in range(i):
                 if tags[i] == tags[j]:
                     # check paradigm of reused tasks
-                    assert instances[i].task_paradigm == \
-                        instances[j].task_paradigm, \
+                    assert instances[i].Paradigm == \
+                        instances[j].Paradigm, \
                         "paradigm of reuse tasks should be consistent"
                     instances[i].task_reuse_scope = instances[j].name
                     break
@@ -312,18 +324,6 @@ class Controller(object):
             inst.Reader = Reader
             inst.Paradigm = Paradigm
 
-
-        # prepare backbone
-        if 'backbone_config_path' in mtl_conf:
-            bb_conf = _parse_json(mtl_conf['backbone_config_path'])
-            bb_conf = _merge_conf(mtl_conf, bb_conf)
-        else:
-            bb_conf = mtl_conf
-        print_dict(bb_conf, title='backbone configuration'.format(instname))
-
-        bb_name = mtl_conf['backbone']
-        bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
-        Backbone = getattr(bb_mod, 'Model')
 
         self.instances = instances
         self.mrs = mrs
@@ -433,11 +433,12 @@ class Controller(object):
         train_prog = fluid.default_main_program()
         train_init_prog = fluid.default_startup_program()
         # don't use unique_name.guard -- it has no effect, it cannot reach the names in param_attr
-        with fluid.unique_name.guard("backbone-"):
-            bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
+        # with fluid.unique_name.guard("backbone-"):
+        bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
         assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())
-        #for var in train_init_prog.blocks[0].vars:
-        #    print(var)
+        # for block in train_init_prog.blocks:
+        #     for var in block.vars:
+        #         print(var)
 
         # this will crash
         # is it necessary to create a new program here? yes -- learned that the hard way
@@ -445,10 +446,10 @@ class Controller(object):
         pred_prog = fluid.Program()
         pred_init_prog = fluid.Program()
         with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog):
+            # with fluid.unique_name.guard():
             pred_net_inputs = create_net_inputs(pred_input_attrs)
             # don't use unique_name.guard -- it has no effect, it cannot reach the names in param_attr
             # with fluid.unique_name.guard("backbone-"):
-
             pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
 
         fluid.framework.switch_main_program(train_prog)
@@ -465,7 +466,7 @@ class Controller(object):
 
                 scope = inst.task_reuse_scope + '/'
                 with fluid.unique_name.guard(scope):
-                    output_vars = inst.build_task_layer(task_inputs, phase='train')
+                    output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
                     output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()}
                     old = len(task_output_vars) # for debug
                     task_output_vars.update(output_vars)
@@ -485,8 +486,9 @@ class Controller(object):
                 inst.pred_input = cur_inputs
                 pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
                 scope = inst.task_reuse_scope + '/'
+                # note: without fluid.unique_name.guard here, this will crash
                 with fluid.unique_name.guard(scope):
-                    inst.build_task_layer(pred_task_inputs, phase='pred')
+                    inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope)
 
 
         bb_fetches = {k: v.name for k,v in bb_output_vars.items()}
@@ -668,6 +670,7 @@ class Controller(object):
         # save_path = os.path.join(main_conf['save_path'],
         #                          "step_" + str(global_step) + "_final")
         # fluid.io.save_persistables(self.exe, save_path, saver_program)
+        print("ALL tasks train finished, exiting...")
 
     def pred(self, task_instance, inference_model_dir=None):
         if self._for_train:
diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py
index 286226d..1ecf6cb 100644
--- a/paddlepalm/reader/cls.py
+++ b/paddlepalm/reader/cls.py
@@ -41,7 +41,7 @@ class Reader(reader):
         if phase == 'train':
             self._input_file = config['train_file']
             self._num_epochs = None # prevent the iterator from terminating
-            self._shuffle = config.get('shuffle', False)
+            self._shuffle = config.get('shuffle', True)
             # self._shuffle_buffer = config.get('shuffle_buffer', 5000)
         elif phase == 'eval':
             self._input_file = config['dev_file']
@@ -93,7 +93,6 @@ class Reader(reader):
            return outputs
 
        for batch in self._data_generator():
-           print(batch)
            yield list_to_dict(batch)
 
    def get_epoch_outputs(self):
diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py
index d6d1c8d..1d3b7a6 100644
--- a/paddlepalm/task_instance.py
+++ b/paddlepalm/task_instance.py
@@ -67,8 +67,8 @@ class TaskInstance(object):
             'fetch_list': 'self._pred_fetch_name_list'}
 
 
-    def build_task_layer(self, net_inputs, phase):
-        output_vars = self._task_layer[phase].build(net_inputs)
+    def build_task_layer(self, net_inputs, phase, scope=""):
+        output_vars = self._task_layer[phase].build(net_inputs, scope_name=scope)
         if phase == 'pred':
             if output_vars is not None:
                 self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
@@ -90,10 +90,10 @@ class TaskInstance(object):
 
         # del self._pred_input_varname_list[0]
         # del self._pred_input_varname_list[0]
         # del self._pred_input_varname_list[0]
+        # print(self._pred_input_varname_list)
         fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
         # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, params_filename='__params__')
-        print(self._name + ': inference model saved at ' + dirpath)
 
         conf = {}
         for k, strv in self._save_protocol.items():
@@ -101,6 +101,7 @@ class TaskInstance(object):
             conf[k] = v
         with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
             writer.write(json.dumps(conf, indent=1))
+        print(self._name + ': inference model saved at ' + dirpath)
 
     def load(self, infer_model_path=None):
         if infer_model_path is None:
@@ -273,9 +274,6 @@
 
 
 
 
 
-
-
-
 def check_instances(insts):
diff --git a/paddlepalm/task_paradigm/cls.py b/paddlepalm/task_paradigm/cls.py
index d86afe3..75158aa 100644
--- a/paddlepalm/task_paradigm/cls.py
+++ b/paddlepalm/task_paradigm/cls.py
@@ -21,7 +21,7 @@ class TaskParadigm(task_paradigm):
     '''
     classification
     '''
-    def __init___(self, config, phase, backbone_config=None):
+    def __init__(self, config, phase, backbone_config=None):
         self._is_training = phase == 'train'
         self._hidden_size = backbone_config['hidden_size']
         self.num_classes = config['n_classes']
@@ -50,13 +50,12 @@ class TaskParadigm(task_paradigm):
         if self._is_training:
             return {'loss': [[1], 'float32']}
         else:
-            return {'logits': [-1, self.num_classes], 'float32'}
+            return {'logits': [[-1, self.num_classes], 'float32']}
 
-    def build(self, **inputs):
+    def build(self, inputs, scope_name=''):
         sent_emb = inputs['backbone']['sentence_embedding']
-        label_ids = inputs['reader']['label_ids']
-
         if self._is_training:
+            label_ids = inputs['reader']['label_ids']
             cls_feats = fluid.layers.dropout(
                 x=sent_emb,
                 dropout_prob=self._dropout_prob,
@@ -66,10 +65,10 @@ class TaskParadigm(task_paradigm):
             input=sent_emb,
             size=self.num_classes,
             param_attr=fluid.ParamAttr(
-                name="cls_out_w",
+                name=scope_name+"cls_out_w",
                 initializer=self._param_initializer),
             bias_attr=fluid.ParamAttr(
-                name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
+                name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
 
         if self._is_training:
             loss = fluid.layers.softmax_with_cross_entropy(
diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task_paradigm/match.py
index 698bedb..8341808 100644
--- a/paddlepalm/task_paradigm/match.py
+++ b/paddlepalm/task_paradigm/match.py
@@ -52,7 +52,7 @@ class TaskParadigm(task_paradigm):
         else:
             return {"logits": [[-1, 1], 'float32']}
 
-    def build(self, inputs):
+    def build(self, inputs, scope_name=""):
         if self._is_training:
             labels = inputs["reader"]["label_ids"]
         cls_feats = inputs["backbone"]["sentence_pair_embedding"]
@@ -67,10 +67,10 @@ class TaskParadigm(task_paradigm):
             input=cls_feats,
             size=2,
             param_attr=fluid.ParamAttr(
-                name="cls_out_w",
+                name=scope_name+"cls_out_w",
                 initializer=self._param_initializer),
             bias_attr=fluid.ParamAttr(
-                name="cls_out_b",
+                name=scope_name+"cls_out_b",
                 initializer=fluid.initializer.Constant(0.)))
 
         if self._is_training:
diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task_paradigm/mlm.py
index c82641a..0ec5d31 100644
--- a/paddlepalm/task_paradigm/mlm.py
+++ b/paddlepalm/task_paradigm/mlm.py
@@ -50,7 +50,7 @@ class TaskParadigm(task_paradigm):
         else:
             return {"logits": [[-1], 'float32']}
 
-    def build(self, inputs):
+    def build(self, inputs, scope_name=""):
         mask_pos = inputs["reader"]["mask_pos"]
         if self._is_training:
             mask_label = inputs["reader"]["mask_label"]
@@ -79,15 +79,15 @@ class TaskParadigm(task_paradigm):
             size=emb_size,
             act=self._hidden_act,
             param_attr=fluid.ParamAttr(
-                name='mask_lm_trans_fc.w_0',
+                name=scope_name+'mask_lm_trans_fc.w_0',
                 initializer=_param_initializer),
-            bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
+            bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
 
         # transform: layer norm
         mask_trans_feat = pre_process_layer(
-            mask_trans_feat, 'n', name='mask_lm_trans')
+            mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
 
         mask_lm_out_bias_attr = fluid.ParamAttr(
-            name="mask_lm_out_fc.b_0",
+            name=scope_name+"mask_lm_out_fc.b_0",
             initializer=fluid.initializer.Constant(value=0.0))
         # print fluid.default_main_program().global_block()
diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task_paradigm/mrc.py
index 795dd5b..b1f0b56 100644
--- a/paddlepalm/task_paradigm/mrc.py
+++ b/paddlepalm/task_paradigm/mrc.py
@@ -73,7 +73,7 @@ class TaskParadigm(task_paradigm):
                 'unique_ids': [[-1, 1], 'int64']}
 
 
-    def build(self, inputs):
+    def build(self, inputs, scope_name=""):
         if self._is_training:
             start_positions = inputs['reader']['start_positions']
             end_positions = inputs['reader']['end_positions']
@@ -91,10 +91,10 @@ class TaskParadigm(task_paradigm):
             size=2,
             num_flatten_dims=2,
             param_attr=fluid.ParamAttr(
-                name="cls_squad_out_w",
+                name=scope_name+"cls_squad_out_w",
                 initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
             bias_attr=fluid.ParamAttr(
-                name="cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
+                name=scope_name+"cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
 
         logits = fluid.layers.transpose(x=logits, perm=[2, 0, 1])
         start_logits, end_logits = fluid.layers.unstack(x=logits, axis=0)
diff --git a/run_demo2.sh b/run_demo2.sh
index 128910e..232d3c8 100755
--- a/run_demo2.sh
+++ b/run_demo2.sh
@@ -1,4 +1,8 @@
+set -e
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-python -u demo2.py
+while true
+do
+    python -u demo2.py
+done
 
 # GLOG_vmodule=lookup_table_op=4 python -u demo2.py > debug2.log 2>&1
-- 
GitLab