提交 5a95b380 编写于 作者: X xixiaoyao

fix reader

上级 7b82c250
...@@ -90,7 +90,6 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'): ...@@ -90,7 +90,6 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
queue.task_done() queue.task_done()
if ret is not None: if ret is not None:
batches, num_pad = ret batches, num_pad = ret
id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
batch_buf = [] batch_buf = []
flag_buf = [] flag_buf = []
for idx, batch in enumerate(batches): for idx, batch in enumerate(batches):
...@@ -98,10 +97,11 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'): ...@@ -98,10 +97,11 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
flag = idx-len(batches) < -num_pad flag = idx-len(batches) < -num_pad
# if num_pad > 0: # if num_pad > 0:
# num_pad -= 1 # num_pad -= 1
batch = postprocess_fn(batch, id) # batch = postprocess_fn(batch, id)
batch = postprocess_fn(batch)
batch_buf.append(batch) batch_buf.append(batch)
flag_buf.append(flag) flag_buf.append(flag)
yield batch_buf, flag_buf, id yield batch_buf, flag_buf
else: else:
break break
queue.join() queue.join()
......
...@@ -3,6 +3,7 @@ from paddle import fluid ...@@ -3,6 +3,7 @@ from paddle import fluid
from paddle.fluid import layers from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
from paddlepalm import Trainer from paddlepalm import Trainer
from paddlepalm.utils.reader_helper import create_multihead_iterator_fn, create_multihead_feed_batch_process_fn
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
VERBOSE=False VERBOSE=False
...@@ -63,14 +64,6 @@ class MultiHeadTrainer(Trainer): ...@@ -63,14 +64,6 @@ class MultiHeadTrainer(Trainer):
head = head_dict[self._trainers[i].name] head = head_dict[self._trainers[i].name]
# loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog) # loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
loss_var = self._trainers[i].build_forward(backbone, head) loss_var = self._trainers[i].build_forward(backbone, head)
print(self._trainers[i].name)
print(self._trainers[i].name)
print(self._trainers[i].name)
print(self._trainers[i].name)
print(i)
print(i)
print(i)
print(i)
return loss_var return loss_var
# task_fns = {} # task_fns = {}
...@@ -82,8 +75,9 @@ class MultiHeadTrainer(Trainer): ...@@ -82,8 +75,9 @@ class MultiHeadTrainer(Trainer):
# task_fns[i] = task_loss() # task_fns[i] = task_loss()
task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
print(task_fns) # task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}
with fluid.program_guard(train_prog, train_init_prog): with fluid.program_guard(train_prog, train_init_prog):
head_id_var = fluid.data(name="branch",shape=[1],dtype='int64') head_id_var = fluid.data(name="branch",shape=[1],dtype='int64')
...@@ -115,7 +109,7 @@ class MultiHeadTrainer(Trainer): ...@@ -115,7 +109,7 @@ class MultiHeadTrainer(Trainer):
trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference]) trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference])
base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch
name_to_position = [] input_names = []
joint_shape_and_dtypes = [] joint_shape_and_dtypes = []
iterators = [] iterators = []
prefixes = [] prefixes = []
...@@ -126,9 +120,9 @@ class MultiHeadTrainer(Trainer): ...@@ -126,9 +120,9 @@ class MultiHeadTrainer(Trainer):
assert t.name in reader_dict assert t.name in reader_dict
assert reader_dict[t.name].num_epochs is None, "{}: num_epochs is not None. \ assert reader_dict[t.name].num_epochs is None, "{}: num_epochs is not None. \
To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name) To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name)
print(num_epochs, t.mix_ratio, base_steps_pur_epoch) # print(num_epochs, t.mix_ratio, base_steps_pur_epoch)
max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch) max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
if not t.set_as_aux: if not t._as_auxilary:
print('{}: expected train steps {}.'.format(t.name, max_train_steps)) print('{}: expected train steps {}.'.format(t.name, max_train_steps))
global_steps += max_train_steps global_steps += max_train_steps
if t.name != sampling_reference: if t.name != sampling_reference:
...@@ -137,14 +131,14 @@ class MultiHeadTrainer(Trainer): ...@@ -137,14 +131,14 @@ class MultiHeadTrainer(Trainer):
prefixes.append(t.name) prefixes.append(t.name)
mrs.append(t.mix_ratio) mrs.append(t.mix_ratio)
iterators.append(t._raw_iterator_fn()) iterators.append(t._raw_iterator_fn())
name_to_position.append(t._name_to_position) input_names.append(t._input_names)
joint_shape_and_dtypes.append(t._shape_and_dtypes) joint_shape_and_dtypes.append(t._shape_and_dtypes)
print('Estimated overall train steps {}.'.format(global_steps)) print('Estimated overall train steps {}.'.format(global_steps))
self._overall_train_steps = global_steps self._overall_train_steps = global_steps
iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \ iterator_fn = create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') mrs, input_names, dev_count=dev_count)
feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs) feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs)
if gpu_dev_count > 1: if gpu_dev_count > 1:
...@@ -187,8 +181,8 @@ class MultiHeadTrainer(Trainer): ...@@ -187,8 +181,8 @@ class MultiHeadTrainer(Trainer):
time_begin = time.time() time_begin = time.time()
for feed in iterator: for feed in iterator:
print(feed) print(feed)
batch, task_id = feed # batch, task_id = feed
rt_outputs = self.train_one_step(batch, task_id) rt_outputs, task_id = self.train_one_step(feed, task_id)
task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')} task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')}
self._task_head.batch_postprocess(task_rt_outputs) self._task_head.batch_postprocess(task_rt_outputs)
...@@ -222,20 +216,23 @@ class MultiHeadTrainer(Trainer): ...@@ -222,20 +216,23 @@ class MultiHeadTrainer(Trainer):
break break
def train_one_step(self, batch, task_id): def train_one_step(self, batch):
if dev_count > 1: if dev_count > 1:
assert isinstance(batch, list) assert isinstance(batch, list)
for f in batch: # for f in batch:
f['branch'] = np.array([task_id], dtype='int64') # f['branch'] = np.array([task_id], dtype='int64')
task_id = batch[0]['__task_id']
else: else:
assert isinstance(batch, dict) assert isinstance(batch, dict)
batch['branch'] = np.array([task_id], dtype='int64') task_id = batch['__task_id']
# batch['branch'] = np.array([task_id], dtype='int64')
# feed = self._trainers[task_id].get_one_batch() # feed = self._trainers[task_id].get_one_batch()
rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog) rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog)
self._cur_train_steps += 1 self._cur_train_steps += 1
return rt_outputs, task_id
# if dev_count > 1: # if dev_count > 1:
# # feed, mask, task_id = batch # # feed, mask, task_id = batch
......
...@@ -41,7 +41,7 @@ class Trainer(object): ...@@ -41,7 +41,7 @@ class Trainer(object):
self._train_init = False self._train_init = False
self._predict_init = False self._predict_init = False
self._check_save = lambda: False nelf._check_save = lambda: False
# if save_predict_model: # if save_predict_model:
# self._save_predict_model = True # self._save_predict_model = True
...@@ -97,10 +97,12 @@ class Trainer(object): ...@@ -97,10 +97,12 @@ class Trainer(object):
pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name) pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)
# pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader'] # pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader']
# _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
# _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') # _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
# _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') # _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
# _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') # _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')
pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False)
pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)] pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)]
self._pred_shape_and_dtypes = pred_shape_and_dtypes self._pred_shape_and_dtypes = pred_shape_and_dtypes
self._pred_name_to_position = pred_name_to_position self._pred_name_to_position = pred_name_to_position
...@@ -161,10 +163,11 @@ class Trainer(object): ...@@ -161,10 +163,11 @@ class Trainer(object):
# merge reader input attrs from backbone and task_instances # merge reader input attrs from backbone and task_instances
input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False)
# shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]
self._shape_and_dtypes = shape_and_dtypes self._shape_and_dtypes = shape_and_dtypes
self._name_to_position = name_to_position self._name_to_position = name_to_position
self._input_names = input_names
if DEBUG: if DEBUG:
print('----- for debug -----') print('----- for debug -----')
...@@ -237,8 +240,6 @@ class Trainer(object): ...@@ -237,8 +240,6 @@ class Trainer(object):
# print("[debug] : %d, %s" % (_id, var)) # print("[debug] : %d, %s" % (_id, var))
self._loss_var = loss_var self._loss_var = loss_var
return loss_var return loss_var
def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
# assert not self._multi_task, "you cannot build_backward in trainer when a train is wrapper by MultiHeadTrainer." # assert not self._multi_task, "you cannot build_backward in trainer when a train is wrapper by MultiHeadTrainer."
# build optimizer # build optimizer
assert self._train_init_prog is not None, "train graph not foung! You should build_forward first." assert self._train_init_prog is not None, "train graph not foung! You should build_forward first."
...@@ -285,6 +286,9 @@ class Trainer(object): ...@@ -285,6 +286,9 @@ class Trainer(object):
# print(self._train_prog) # print(self._train_prog)
def set_as_aux(self):
    """Flag this trainer as an auxiliary task by setting `_as_auxilary` to True.

    NOTE(review): no default for `_as_auxilary` is visible in this view;
    confirm it is initialised elsewhere so readers of the flag do not hit
    AttributeError when this method was never called.
    """
    self._as_auxilary = True
def fit_reader(self, reader, phase='train'): def fit_reader(self, reader, phase='train'):
# assert not self._multi_task, "you cannot fit_reader in trainer when a train is wrapper by MultiHeadTrainer." # assert not self._multi_task, "you cannot fit_reader in trainer when a train is wrapper by MultiHeadTrainer."
# load data # load data
...@@ -316,6 +320,11 @@ class Trainer(object): ...@@ -316,6 +320,11 @@ class Trainer(object):
print('ok!') print('ok!')
# merge dataset iterators and create net input vars
iterator = reader._iterator()
prefix = self.name
# merge dataset iterators and create net input vars # merge dataset iterators and create net input vars
iterator = reader._iterator() iterator = reader._iterator()
prefix = self.name prefix = self.name
...@@ -659,10 +668,5 @@ class Trainer(object): ...@@ -659,10 +668,5 @@ class Trainer(object):
@mix_ratio.setter @mix_ratio.setter
def mix_ratio(self, value): def mix_ratio(self, value):
self._mix_ratio = float(value)
if self._verbose:
print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
def _set_lock(self):
self._lock = True self._lock = True
...@@ -36,24 +36,24 @@ def create_feed_batch_process_fn(net_inputs): ...@@ -36,24 +36,24 @@ def create_feed_batch_process_fn(net_inputs):
return feed_batch_process_fn return feed_batch_process_fn
def create_multihead_feed_batch_process_fn(net_inputs): # def create_multihead_feed_batch_process_fn(net_inputs):
#
def feed_batch_process_fn(data, id=-1): # def feed_batch_process_fn(data, id=-1):
# temps = {} # # temps = {}
# for i in range(len(net_inputs)): # # for i in range(len(net_inputs)):
temp = {} # temp = {}
inputs = net_inputs[id] if id != -1 else net_inputs # inputs = net_inputs[id] if id != -1 else net_inputs
#
for q, var in inputs.items(): # for q, var in inputs.items():
if isinstance(var, str) or isinstance(var, unicode): # if isinstance(var, str) or isinstance(var, unicode):
temp[var] = data[q] # temp[var] = data[q]
else: # else:
temp[var.name] = data[q] # temp[var.name] = data[q]
# temps[i] = temp # # temps[i] = temp
#
return temp # return temp
#
return feed_batch_process_fn # return feed_batch_process_fn
def _check_and_adapt_shape_dtype(rt_val, attr, message=""): def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
...@@ -147,6 +147,41 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p ...@@ -147,6 +147,41 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p
return iterator_fn return iterator_fn
def create_multihead_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, names, dev_count=1, keep_one_task=True):
    """Build a generator function that samples batches from several task iterators.

    Args:
        iterators: one batch iterator per task; each `next()` must return a
            dict mapping output name -> value.
        iterator_prefixes: one prefix string per task, used to build the
            task-scoped name `prefix + '.' + outname`.
        joint_shape_and_dtypes: per task, a sequence of shape/dtype attrs that
            is positionally aligned with `names[task]`.
        mrs: per-task mix ratios; normalised here into sampling probabilities.
        names: per task, the list of input names the network expects.
        dev_count: number of batches drawn from the sampled task before a new
            task is sampled.
        keep_one_task: when False, `dev_count` is forced to 1.

    Returns:
        A zero-argument generator function. Each yielded item is a dict that
        holds `'__task_id'` (int64 array of shape [1]) plus every reader output
        whose bare or prefixed name appears in `names[task]`, after it has been
        passed through `_check_and_adapt_shape_dtype`.
    """
    task_ids = range(len(iterators))
    total_ratio = float(sum(mrs))
    weights = [mr / total_ratio for mr in mrs]
    if not keep_one_task:
        dev_count = 1

    # Position of each expected input name inside joint_shape_and_dtypes[task].
    # The original code indexed with an undefined `idx`; the position is
    # recovered from `names`, which is aligned with the shape/dtype list.
    # setdefault keeps the first occurrence, mirroring list.index().
    name_to_pos = []
    for task_names in names:
        positions = {}
        for pos, name in enumerate(task_names):
            positions.setdefault(name, pos)
        name_to_pos.append(positions)

    def iterator():
        while True:
            task_id = np.random.choice(task_ids, p=weights)
            task_id_tensor = np.array([task_id]).astype("int64")

            for _ in range(dev_count):
                outputs = next(iterators[task_id])  # dict type
                prefix = iterator_prefixes[task_id]
                positions = name_to_pos[task_id]
                attrs = joint_shape_and_dtypes[task_id]

                results = {}
                results['__task_id'] = task_id_tensor
                for outname, val in outputs.items():
                    task_outname = prefix + '.' + outname

                    if outname in positions:
                        val = _check_and_adapt_shape_dtype(val, attrs[positions[outname]], message=outname+': ')
                        results[outname] = val

                    if task_outname in positions:
                        val = _check_and_adapt_shape_dtype(val, attrs[positions[task_outname]], message=task_outname+': ')
                        results[task_outname] = val

                yield results

    return iterator
def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0): def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0):
""" """
...@@ -250,7 +285,7 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype ...@@ -250,7 +285,7 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
return iterator return iterator
def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=True, insert_seqlen=True, insert_batchsize_x_seqlen=True): def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False):
""" """
Args: Args:
task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks
...@@ -262,7 +297,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc ...@@ -262,7 +297,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc
names = [] names = []
start = 0 start = 0
if insert_taskid: if insert_taskid:
ret.append(([1,1], 'int64')) ret.append(([1], 'int64'))
names.append('__task_id') names.append('__task_id')
start += 1 start += 1
...@@ -283,16 +318,11 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc ...@@ -283,16 +318,11 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc
names += sorted(backbone_attr.keys()) names += sorted(backbone_attr.keys())
ret.extend([backbone_attr[k] for k in names[start:]]) ret.extend([backbone_attr[k] for k in names[start:]])
name_to_position = {}
# pos=0 is for task_id, thus we start from 1 # pos=0 is for task_id, thus we start from 1
for pos, k in enumerate(names):
name_to_position[k] = pos
for task_attr in task_attrs: for task_attr in task_attrs:
task_names = sorted(task_attr.keys()) task_names = sorted(task_attr.keys())
names.extend(task_names) names.extend(task_names)
ret.extend([task_attr[k] for k in task_names]) ret.extend([task_attr[k] for k in task_names])
for pos, k in enumerate(task_names, start=len(name_to_position)): return names, ret
name_to_position[k] = pos
return names, ret, name_to_position
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册