diff --git a/paddlepalm/distribute/reader.py b/paddlepalm/distribute/reader.py
index 6532e0d68b0d12d838f3f41d248efea8a2eb908b..8b36569b90152643712a33a49c238873df5757fd 100644
--- a/paddlepalm/distribute/reader.py
+++ b/paddlepalm/distribute/reader.py
@@ -90,7 +90,6 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
         queue.task_done()
         if ret is not None:
             batches, num_pad = ret
-            id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
             batch_buf = []
             flag_buf = []
             for idx, batch in enumerate(batches):
@@ -98,10 +97,11 @@
                 flag = idx-len(batches) < -num_pad
                 # if num_pad > 0:
                 #     num_pad -= 1
-                batch = postprocess_fn(batch, id)
+                # batch = postprocess_fn(batch, id)
+                batch = postprocess_fn(batch)
                 batch_buf.append(batch)
                 flag_buf.append(flag)
-            yield batch_buf, flag_buf, id
+            yield batch_buf, flag_buf
         else:
             break
     queue.join()
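A note on the hunk above: task dispatch now happens downstream (each batch keeps its __task_id entry), so postprocess_fn is called with the batch alone. The flag computed from num_pad still marks which per-device batches are real and which are padded copies. A minimal, self-contained sketch of that flag logic, with illustrative values:

    # Four device batches; the last one is a pad copy (num_pad = 1).
    batches = ['b0', 'b1', 'b2', 'b3']
    num_pad = 1

    # idx - len(batches) < -num_pad is True only for the real batches.
    flags = [idx - len(batches) < -num_pad for idx in range(len(batches))]
    print(flags)  # [True, True, True, False]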
diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py
index b65b066ccde747f98df311d45a48cd5fceec9c3e..8b86e53dfef76e5a01631ac9b60bca1a728e9e38 100644
--- a/paddlepalm/multihead_trainer.py
+++ b/paddlepalm/multihead_trainer.py
@@ -3,6 +3,7 @@ from paddle import fluid
 from paddle.fluid import layers
 from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
 from paddlepalm import Trainer
+from paddlepalm.utils.reader_helper import create_multihead_iterator_fn, create_multihead_feed_batch_process_fn

 dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
 VERBOSE=False
@@ -63,14 +64,6 @@
             head = head_dict[self._trainers[i].name]
             # loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
             loss_var = self._trainers[i].build_forward(backbone, head)
-            print(self._trainers[i].name)
-            print(self._trainers[i].name)
-            print(self._trainers[i].name)
-            print(self._trainers[i].name)
-            print(i)
-            print(i)
-            print(i)
-            print(i)
             return loss_var

         # task_fns = {}
@@ -82,8 +75,9 @@
         #     task_fns[i] = task_loss()


-        task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
-        print(task_fns)
+
+        # task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
+        task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}

         with fluid.program_guard(train_prog, train_init_prog):
             head_id_var = fluid.data(name="branch",shape=[1],dtype='int64')
@@ -115,7 +109,7 @@
         trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference])
         base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch

-        name_to_position = []
+        input_names = []
         joint_shape_and_dtypes = []
         iterators = []
         prefixes = []
@@ -126,9 +120,9 @@
             assert t.name in reader_dict
             assert reader_dict[t.name].num_epochs is None, "{}: num_epochs is not None. \
                 To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name)
-            print(num_epochs, t.mix_ratio, base_steps_pur_epoch)
+            # print(num_epochs, t.mix_ratio, base_steps_pur_epoch)
             max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
-            if not t.set_as_aux:
+            if not t._as_auxilary:
                 print('{}: expected train steps {}.'.format(t.name, max_train_steps))
             global_steps += max_train_steps
             if t.name != sampling_reference:
@@ -137,14 +131,14 @@
             prefixes.append(t.name)
             mrs.append(t.mix_ratio)
             iterators.append(t._raw_iterator_fn())
-            name_to_position.append(t._name_to_position)
+            input_names.append(t._input_names)
             joint_shape_and_dtypes.append(t._shape_and_dtypes)

         print('Estimated overall train steps {}.'.format(global_steps))
         self._overall_train_steps = global_steps

-        iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
-            mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
+        iterator_fn = create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
+            mrs, input_names, dev_count=dev_count)

         feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs)
         if gpu_dev_count > 1:
@@ -187,8 +181,8 @@
         time_begin = time.time()
         for feed in iterator:
             print(feed)
-            batch, task_id = feed
-            rt_outputs = self.train_one_step(batch, task_id)
+            # batch, task_id = feed
+            rt_outputs, task_id = self.train_one_step(feed)

             task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')}
             self._task_head.batch_postprocess(task_rt_outputs)
@@ -222,20 +216,23 @@
                 break


-    def train_one_step(self, batch, task_id):
+    def train_one_step(self, batch):
         if dev_count > 1:
             assert isinstance(batch, list)
-            for f in batch:
-                f['branch'] = np.array([task_id], dtype='int64')
+            # for f in batch:
+            #     f['branch'] = np.array([task_id], dtype='int64')
+            task_id = batch[0]['__task_id']
         else:
             assert isinstance(batch, dict)
-            batch['branch'] = np.array([task_id], dtype='int64')
+            task_id = batch['__task_id']
+            # batch['branch'] = np.array([task_id], dtype='int64')

         # feed = self._trainers[task_id].get_one_batch()
         rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog)
         self._cur_train_steps += 1
+        return rt_outputs, task_id

         # if dev_count > 1:
         #     # feed, mask, task_id = batch

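The task_fns change above fixes a classic late-binding bug: a bare lambda in a dict comprehension captures the loop variable itself, so every branch would call get_loss with the final head index. Binding the index through a default argument (lambda i=i: ...) freezes one value per lambda. A self-contained demonstration, independent of the patch:

    late = {i: (lambda: i) for i in range(3)}
    print([late[k]() for k in range(3)])   # [2, 2, 2], all closures see the last i

    bound = {i: (lambda i=i: i) for i in range(3)}
    print([bound[k]() for k in range(3)])  # [0, 1, 2], value frozen per lambda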
diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py
index 30c0a78c6d5c3f71446cdfd3e9d44bca0e924112..2022c5ca8d9d0464bc6aaae0ba3c634b3204221e 100644
--- a/paddlepalm/trainer.py
+++ b/paddlepalm/trainer.py
@@ -41,7 +41,7 @@ class Trainer(object):
         self._train_init = False
         self._predict_init = False

-        self._check_save = lambda: False
+        self._check_save = lambda: False

         # if save_predict_model:
         #     self._save_predict_model = True
@@ -97,10 +97,12 @@
         pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)
         # pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader']

+        # _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
+        # _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
         # _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
         # _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')
-        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
+        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False)
         pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)]
         self._pred_shape_and_dtypes = pred_shape_and_dtypes
         self._pred_name_to_position = pred_name_to_position
@@ -161,10 +163,11 @@

         # merge reader input attrs from backbone and task_instances
-        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
+        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False)
         # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]
         self._shape_and_dtypes = shape_and_dtypes
         self._name_to_position = name_to_position
+        self._input_names = input_names

         if DEBUG:
             print('----- for debug -----')
@@ -237,8 +240,6 @@
         #     print("[debug] : %d, %s" % (_id, var))
         self._loss_var = loss_var
         return loss_var
-
-
     def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
         # assert not self._multi_task, "you cannot build_backward in trainer when a train is wrapper by MultiHeadTrainer."
         # build optimizer
         assert self._train_init_prog is not None, "train graph not foung! You should build_forward first."
@@ -285,6 +286,9 @@

         # print(self._train_prog)

+    def set_as_aux(self):
+        self._as_auxilary = True
+
     def fit_reader(self, reader, phase='train'):
         # assert not self._multi_task, "you cannot fit_reader in trainer when a train is wrapper by MultiHeadTrainer."
         # load data
@@ -316,6 +320,11 @@
             print('ok!')

+        # merge dataset iterators and create net input vars
+        iterator = reader._iterator()
+        prefix = self.name
+
+
         # merge dataset iterators and create net input vars
         iterator = reader._iterator()
         prefix = self.name
@@ -659,10 +668,5 @@

     @mix_ratio.setter
     def mix_ratio(self, value):
-        self._mix_ratio = float(value)
-        if self._verbose:
-            print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
-
-    def _set_lock(self):
         self._lock = True
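The _input_names list stored above is parallel to _shape_and_dtypes, so an input's expected shape and dtype can be looked up by position from its name. An illustrative sketch (the names and shapes are made up, not taken from the patch):

    input_names = ['__task_id', 'token_ids', 'label_ids']
    shape_and_dtypes = [([1], 'int64'), ([-1, -1], 'int64'), ([-1, 1], 'int64')]

    idx = input_names.index('token_ids')
    shape, dtype = shape_and_dtypes[idx]
    print(shape, dtype)  # [-1, -1] int64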
diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py
index d0727aaa2c8a8ecd6692c15fa8c219a2a23bf1ff..0f8f046bc9ecfa7b5019d5d628a52e066de76401 100644
--- a/paddlepalm/utils/reader_helper.py
+++ b/paddlepalm/utils/reader_helper.py
@@ -36,24 +36,24 @@ def create_feed_batch_process_fn(net_inputs):

     return feed_batch_process_fn

-# def create_multihead_feed_batch_process_fn(net_inputs):
-#
-#     def feed_batch_process_fn(data, id=-1):
-#         # temps = {}
-#         # for i in range(len(net_inputs)):
-#         temp = {}
-#         inputs = net_inputs[id] if id != -1 else net_inputs
-#
-#         for q, var in inputs.items():
-#             if isinstance(var, str) or isinstance(var, unicode):
-#                 temp[var] = data[q]
-#             else:
-#                 temp[var.name] = data[q]
-#         # temps[i] = temp
-#
-#         return temp
-#
-#     return feed_batch_process_fn
+def create_multihead_feed_batch_process_fn(net_inputs):
+
+    def feed_batch_process_fn(data, id=-1):
+        # temps = {}
+        # for i in range(len(net_inputs)):
+        temp = {}
+        inputs = net_inputs[id] if id != -1 else net_inputs
+
+        for q, var in inputs.items():
+            if isinstance(var, str) or isinstance(var, unicode):
+                temp[var] = data[q]
+            else:
+                temp[var.name] = data[q]
+        # temps[i] = temp
+
+        return temp
+
+    return feed_batch_process_fn


 def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
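For reference, feed_batch_process_fn (kept live above, since the patch imports and calls it elsewhere) re-keys a reader's output dict by the names of the matching graph variables. A minimal sketch of that mapping with made-up values, where net_inputs maps attribute names to variables or variable names:

    net_inputs = {'token_ids': 'token_ids_var'}
    data = {'token_ids': [[1, 2, 3]]}

    feed = {}
    for q, var in net_inputs.items():
        name = var if isinstance(var, str) else var.name
        feed[name] = data[q]
    print(feed)  # {'token_ids_var': [[1, 2, 3]]}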
@@ -147,6 +147,41 @@
     return iterator_fn


+def create_multihead_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, names, dev_count=1, keep_one_task=True):
+    task_ids = range(len(iterators))
+    weights = [mr / float(sum(mrs)) for mr in mrs]
+    if not keep_one_task:
+        dev_count = 1
+
+    def iterator():
+        while True:
+            id = np.random.choice(task_ids, p=weights)
+            task_id_tensor = np.array([id]).astype("int64")
+
+            for i in range(dev_count):
+
+                outputs = next(iterators[id]) # dict type
+
+                prefix = iterator_prefixes[id]
+                results = {}
+                results['__task_id'] = task_id_tensor
+                for outname, val in outputs.items():
+                    task_outname = prefix + '.' + outname
+
+                    if outname in names[id]:
+                        idx = names[id].index(outname)
+                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
+                        results[outname] = val
+
+                    if task_outname in names[id]:
+                        idx = names[id].index(task_outname)
+                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
+                        results[task_outname] = val
+
+                yield results
+
+    return iterator
+

 def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0):
     """
@@ -250,7 +285,7 @@
     return iterator


-def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=True, insert_seqlen=True, insert_batchsize_x_seqlen=True):
+def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False):
     """
     Args:
         task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks
@@ -262,7 +297,7 @@
     names = []
     start = 0
     if insert_taskid:
-        ret.append(([1,1], 'int64'))
+        ret.append(([1], 'int64'))
         names.append('__task_id')
         start += 1
@@ -283,16 +318,11 @@
     names += sorted(backbone_attr.keys())
     ret.extend([backbone_attr[k] for k in names[start:]])

-    name_to_position = {} # pos=0 is for task_id, thus we start from 1
-    for pos, k in enumerate(names):
-        name_to_position[k] = pos

     for task_attr in task_attrs:
         task_names = sorted(task_attr.keys())
         names.extend(task_names)
         ret.extend([task_attr[k] for k in task_names])
-        for pos, k in enumerate(task_names, start=len(name_to_position)):
-            name_to_position[k] = pos
-    return names, ret, name_to_position
+    return names, ret
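The sampling step in create_multihead_iterator_fn draws the next task in proportion to its mix ratio. A runnable sketch of just that mechanism, with illustrative ratios:

    import numpy as np

    mrs = [1.0, 0.5, 0.5]                           # per-task mix ratios
    weights = [mr / float(sum(mrs)) for mr in mrs]  # [0.5, 0.25, 0.25]
    task_ids = range(len(mrs))

    counts = [0] * len(mrs)
    for _ in range(10000):
        id = np.random.choice(task_ids, p=weights)
        counts[id] += 1
    print(counts)  # roughly [5000, 2500, 2500]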