未验证 提交 cd731f08 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #17 from xixiaoyao/master

fix bugs and add demo3
task_instance: "mlm4mrqa"
task_instance: "cls1, cls2, cls3, cls4, cls5, cls6"
save_path: "output_model/firstrun"
task_reuse_tag: 0,0,1,1,0,2
save_path: "output_model/thirdrun"
backbone: "ernie"
backbone_config_path: "pretrain_model/ernie/ernie_config.json"
......@@ -9,8 +11,8 @@ vocab_path: "pretrain_model/ernie/vocab.txt"
do_lower_case: True
max_seq_len: 512
batch_size: 5
num_epochs: 100
batch_size: 4
num_epochs: 0.5
optimizer: "adam"
learning_rate: 3e-5
warmup_proportion: 0.1
......
......@@ -29,7 +29,7 @@ label text_a
0 Trophy hunting can include areas which would likely be unsuitable for what other types of ecotourism?study states that less than 3% of a trophy hunters' expenditures reach the local level, meaning that the economic incentive and benefit is "minimal, particularly when we consider the vast areas of
1 In simple language, what are the interconnections in an embedding matrix?Since it was quite easy to stack interconnections (wires) inside the embedding matrix, the approach allowed designers to forget completely about the routing of wires (usually a time-consuming operation of PCB design): Anywhere the designer needs a connection, the machine will draw a wire in straight line from one location/pin
2 rho has been to the most all star games in baseballn4 </Li> <Li> Stan Musial 24 </Li>
0In 1169, Ireland was invaded by which people?High King to ensure the terms of the Treaty of Windsor led Henry II, as King of England, to rule as effective monarch under the title of Lord of Ireland. This title was granted to his younger son but when Henry's heir unexpectedly died the title of King of England and Lord of Ireland became entwined in one
0 In 1169, Ireland was invaded by which people?High King to ensure the terms of the Treaty of Windsor led Henry II, as King of England, to rule as effective monarch under the title of Lord of Ireland. This title was granted to his younger son but when Henry's heir unexpectedly died the title of King of England and Lord of Ireland became entwined in one
1 What year did a biracial Populist fusion gain the Governors office?to the legislature and governor's office, but the Populists attracted voters displeased with them. In 1896 a biracial, Populist-Republican Fusionist coalition gained the governor's office. The Democrats regained control of the legislature
1 nearest metro station to majnu ka tilla delhiRing Road of Delhi . It is at a walkable distance from ISBT Kashmere Gate . It is approachable through the Kashmeri Gate station of the Delhi Metro , lies on both the Red ( Dilshad Garden - Rithala ) and Yellow Lines ( Samaypur
3 where is california located in the region of the united states<P> California is a U.S. state in the Pacific Region of the United States . With 39.5 million residents , California is the most populous state in the United States and the third largest by area . The
......
train_file: data/cls4mrqa/train.tsv
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/cls4mrqa/train.tsv"
reader: cls
paradigm: cls
n_classes: 4
train_file: "data/mlm4mrqa/train.tsv"
reader: mlm
paradigm: mlm
......@@ -208,7 +208,7 @@ class Controller(object):
self.exe = exe
self.dev_count = dev_count
print_dict(mtl_conf, title='main configuration')
print_dict(mtl_conf, title='global configuration')
# parse task instances and target tags
instnames = _parse_list(mtl_conf['task_instance'])
......@@ -230,6 +230,18 @@ class Controller(object):
instname_to_conf[instname] = conf
instname_to_id[instname] = id
# prepare backbone
if 'backbone_config_path' in mtl_conf:
bb_conf = _parse_json(mtl_conf['backbone_config_path'])
bb_conf = _merge_conf(mtl_conf, bb_conf)
else:
bb_conf = mtl_conf
print_dict(bb_conf, title='backbone configuration'.format(instname))
bb_name = mtl_conf['backbone']
bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
Backbone = getattr(bb_mod, 'Model')
# create task instances
instances = []
for name in instnames:
......@@ -294,8 +306,8 @@ class Controller(object):
for j in range(i):
if tags[i] == tags[j]:
# check paradigm of reused tasks
assert instances[i].task_paradigm == \
instances[j].task_paradigm, \
assert instances[i].Paradigm == \
instances[j].Paradigm, \
"paradigm of reuse tasks should be consistent"
instances[i].task_reuse_scope = instances[j].name
break
......@@ -313,18 +325,6 @@ class Controller(object):
inst.Reader = Reader
inst.Paradigm = Paradigm
# prepare backbone
if 'backbone_config_path' in mtl_conf:
bb_conf = _parse_json(mtl_conf['backbone_config_path'])
bb_conf = _merge_conf(mtl_conf, bb_conf)
else:
bb_conf = mtl_conf
print_dict(bb_conf, title='backbone configuration'.format(instname))
bb_name = mtl_conf['backbone']
bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
Backbone = getattr(bb_mod, 'Model')
self.instances = instances
self.mrs = mrs
self.Backbone = Backbone
......@@ -433,10 +433,11 @@ class Controller(object):
train_prog = fluid.default_main_program()
train_init_prog = fluid.default_startup_program()
# 别用unique_name.guard了,没用的,无法作用到param_attr里的name上
with fluid.unique_name.guard("backbone-"):
# with fluid.unique_name.guard("backbone-"):
bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())
#for var in train_init_prog.blocks[0].vars:
# for block in train_init_prog.blocks:
# for var in block.vars:
# print(var)
# 会挂
......@@ -445,10 +446,10 @@ class Controller(object):
pred_init_prog = fluid.Program()
with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog):
# with fluid.unique_name.guard():
pred_net_inputs = create_net_inputs(pred_input_attrs)
# 别用unique_name.guard了,没用的,无法作用到param_attr里的name上
# with fluid.unique_name.guard("backbone-"):
pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
fluid.framework.switch_main_program(train_prog)
......@@ -465,7 +466,7 @@ class Controller(object):
scope = inst.task_reuse_scope + '/'
with fluid.unique_name.guard(scope):
output_vars = inst.build_task_layer(task_inputs, phase='train')
output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()}
old = len(task_output_vars) # for debug
task_output_vars.update(output_vars)
......@@ -485,8 +486,9 @@ class Controller(object):
inst.pred_input = cur_inputs
pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
scope = inst.task_reuse_scope + '/'
# 注意,这里不加上fluid.unique_name.guard会挂
with fluid.unique_name.guard(scope):
inst.build_task_layer(pred_task_inputs, phase='pred')
inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope)
bb_fetches = {k: v.name for k,v in bb_output_vars.items()}
......@@ -668,6 +670,7 @@ class Controller(object):
# save_path = os.path.join(main_conf['save_path'],
# "step_" + str(global_step) + "_final")
# fluid.io.save_persistables(self.exe, save_path, saver_program)
print("ALL tasks train finished, exiting...")
def pred(self, task_instance, inference_model_dir=None):
if self._for_train:
......
......@@ -41,7 +41,7 @@ class Reader(reader):
if phase == 'train':
self._input_file = config['train_file']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', False)
self._shuffle = config.get('shuffle', True)
# self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
......@@ -93,7 +93,6 @@ class Reader(reader):
return outputs
for batch in self._data_generator():
print(batch)
yield list_to_dict(batch)
def get_epoch_outputs(self):
......
......@@ -67,8 +67,8 @@ class TaskInstance(object):
'fetch_list': 'self._pred_fetch_name_list'}
def build_task_layer(self, net_inputs, phase):
output_vars = self._task_layer[phase].build(net_inputs)
def build_task_layer(self, net_inputs, phase, scope=""):
output_vars = self._task_layer[phase].build(net_inputs, scope_name=scope)
if phase == 'pred':
if output_vars is not None:
self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
......@@ -90,10 +90,10 @@ class TaskInstance(object):
# del self._pred_input_varname_list[0]
# del self._pred_input_varname_list[0]
# del self._pred_input_varname_list[0]
# print(self._pred_input_varname_list)
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
# fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, params_filename='__params__')
print(self._name + ': inference model saved at ' + dirpath)
conf = {}
for k, strv in self._save_protocol.items():
......@@ -101,6 +101,7 @@ class TaskInstance(object):
conf[k] = v
with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
writer.write(json.dumps(conf, indent=1))
print(self._name + ': inference model saved at ' + dirpath)
def load(self, infer_model_path=None):
if infer_model_path is None:
......@@ -275,9 +276,6 @@ class TaskInstance(object):
def check_instances(insts):
"""to check ids, first_target"""
pass
......
......@@ -21,7 +21,7 @@ class TaskParadigm(task_paradigm):
'''
classification
'''
def __init___(self, config, phase, backbone_config=None):
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._hidden_size = backbone_config['hidden_size']
self.num_classes = config['n_classes']
......@@ -50,13 +50,12 @@ class TaskParadigm(task_paradigm):
if self._is_training:
return {'loss': [[1], 'float32']}
else:
return {'logits': [-1, self.num_classes], 'float32'}
return {'logits': [[-1, self.num_classes], 'float32']}
def build(self, **inputs):
def build(self, inputs, scope_name=''):
sent_emb = inputs['backbone']['sentence_embedding']
label_ids = inputs['reader']['label_ids']
if self._is_training:
label_ids = inputs['reader']['label_ids']
cls_feats = fluid.layers.dropout(
x=sent_emb,
dropout_prob=self._dropout_prob,
......@@ -66,10 +65,10 @@ class TaskParadigm(task_paradigm):
input=sent_emb,
size=self.num_classes,
param_attr=fluid.ParamAttr(
name="cls_out_w",
name=scope_name+"cls_out_w",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
if self._is_training:
loss = fluid.layers.softmax_with_cross_entropy(
......
......@@ -52,7 +52,7 @@ class TaskParadigm(task_paradigm):
else:
return {"logits": [[-1, 1], 'float32']}
def build(self, inputs):
def build(self, inputs, scope_name=""):
if self._is_training:
labels = inputs["reader"]["label_ids"]
cls_feats = inputs["backbone"]["sentence_pair_embedding"]
......@@ -67,10 +67,10 @@ class TaskParadigm(task_paradigm):
input=cls_feats,
size=2,
param_attr=fluid.ParamAttr(
name="cls_out_w",
name=scope_name+"cls_out_w",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name="cls_out_b",
name=scope_name+"cls_out_b",
initializer=fluid.initializer.Constant(0.)))
if self._is_training:
......
......@@ -50,7 +50,7 @@ class TaskParadigm(task_paradigm):
else:
return {"logits": [[-1], 'float32']}
def build(self, inputs):
def build(self, inputs, scope_name=""):
mask_pos = inputs["reader"]["mask_pos"]
if self._is_training:
mask_label = inputs["reader"]["mask_label"]
......@@ -79,15 +79,15 @@ class TaskParadigm(task_paradigm):
size=emb_size,
act=self._hidden_act,
param_attr=fluid.ParamAttr(
name='mask_lm_trans_fc.w_0',
name=scope_name+'mask_lm_trans_fc.w_0',
initializer=_param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name='mask_lm_trans')
mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name="mask_lm_out_fc.b_0",
name=scope_name+"mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
# print fluid.default_main_program().global_block()
......
......@@ -73,7 +73,7 @@ class TaskParadigm(task_paradigm):
'unique_ids': [[-1, 1], 'int64']}
def build(self, inputs):
def build(self, inputs, scope_name=""):
if self._is_training:
start_positions = inputs['reader']['start_positions']
end_positions = inputs['reader']['end_positions']
......@@ -91,10 +91,10 @@ class TaskParadigm(task_paradigm):
size=2,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name="cls_squad_out_w",
name=scope_name+"cls_squad_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
name=scope_name+"cls_squad_out_b", initializer=fluid.initializer.Constant(0.)))
logits = fluid.layers.transpose(x=logits, perm=[2, 0, 1])
start_logits, end_logits = fluid.layers.unstack(x=logits, axis=0)
......
set -e
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u demo2.py
while true
do
python -u demo2.py
done
# GLOG_vmodule=lookup_table_op=4 python -u demo2.py > debug2.log 2>&1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册