diff --git a/paddlepalm/distribute/__init__.py b/paddlepalm/distribute/__init__.py index 255478be8053f52934e793000753fc41b41b0035..bff6594676be1cd2e5e7f3882903a67a50e1fcfa 100644 --- a/paddlepalm/distribute/__init__.py +++ b/paddlepalm/distribute/__init__.py @@ -5,5 +5,5 @@ import multiprocessing gpu_dev_count = int(fluid.core.get_cuda_device_count()) cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) -from reader import yield_pieces, data_feeder +from reader import yield_pieces, data_feeder, decode_fake diff --git a/paddlepalm/distribute/reader.py b/paddlepalm/distribute/reader.py index 95cfbf757ecfeb617ea0c95ad35b363b684ee70e..6532e0d68b0d12d838f3f41d248efea8a2eb908b 100644 --- a/paddlepalm/distribute/reader.py +++ b/paddlepalm/distribute/reader.py @@ -58,7 +58,6 @@ def yield_pieces(data, distribute_strategy, batch_size): yield temp def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'): - if postprocess_fn is None: def postprocess_fn(batch): return batch @@ -108,3 +107,15 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'): queue.join() +def decode_fake(nums, mask, bs): + n_t = 0 + for flag in mask: + if not flag: + break + n_t = n_t + 1 + + n_f = len(mask) - n_t + p1 = nums - (n_t-1) * bs + each_f = p1 / (n_f+1) + return each_f * n_f + diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index ada105df5bab1dd2754f6187af9e3269bb873bc9..2a0c7b85ce5ba63c3b49168be74cedd5f7da4e35 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -31,7 +31,7 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint from paddlepalm.utils.config_helper import PDConfig from paddlepalm.utils.print_helper import print_dict from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs -from paddlepalm.distribute import data_feeder +from paddlepalm.distribute import data_feeder, decode_fake from default_settings import * from task_instance import TaskInstance, check_instances @@ -228,6 +228,7 @@ class Controller(object): exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True)) self.exe = exe self.dev_count = dev_count + self.batch_size = mtl_conf.get('batch_size') print_dict(mtl_conf, title='global configuration') @@ -350,7 +351,7 @@ class Controller(object): dev_count = self.dev_count num_instances = len(instances) mrs = self.mrs - branch = fluid.data(name="branch",shape=[1],dtype='int32') + branch = fluid.data(name="branch",shape=[1],dtype='int64') # set first_target/main task instance main_inst = None @@ -536,9 +537,8 @@ class Controller(object): # prepare for train self.train_backbone = train_backbone - #self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) + self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) self.saver_program = fluid.default_main_program() - self.train_program = self.saver_program self.main_inst = main_inst self.has_init_train = True @@ -564,7 +564,7 @@ class Controller(object): insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) pred_prog = inst.load(infer_model_path) - # pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel() + pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel() if inst.reader['pred'] is None: pred_reader = inst.Reader(inst.config, phase='pred') inst.reader['pred'] = pred_reader @@ -628,8 +628,8 @@ class Controller(object): while not train_finish(): feed, mask, id = next(distribute_feeder) - - feed[0].update({'branch':np.array([id],dtype='int32')}) + for i in range(self.dev_count): + feed[i].update({'branch':np.array([id],dtype='int64')}) fetch_list.append(self._switched_loss) rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) rt_loss = rt_outputs.pop() @@ -714,33 +714,23 @@ class Controller(object): buf = [] for feed, mask, id in distribute_feeder: - # print('before run') + rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) - # print('after run') - splited_rt_outputs = [] - for item in rt_outputs: - splited_rt_outputs.append(np.split(item, len(mask))) - - # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)] - # print(mask) - - while mask.pop() == False: - print(mask) - for item in splited_rt_outputs: + + nums_fake = decode_fake(len(rt_outputs[0]), mask, self.batch_size) + while nums_fake: + for item in rt_outputs: item.pop() - rt_outputs = [] - # print('cancat') - for item in splited_rt_outputs: - rt_outputs.append(np.concatenate(item)) - + nums_fake = nums_fake - 1 + rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} inst.postprocess(rt_outputs, phase='pred') - # print('leave feeder') + if inst.task_layer['pred'].epoch_inputs_attrs: reader_outputs = inst.reader['pred'].get_epoch_outputs() else: reader_outputs = None - # print('epoch postprocess') + inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') @@ -754,4 +744,4 @@ if __name__ == '__main__': -__all__ = ["Controller"] \ No newline at end of file +__all__ = ["Controller"]