diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index 63af1a456dbaa3ae3835976b6175a6803c876cbc..3d2bb5bcc1d35b3fe25128dc15e3c792ef2a1d7e 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -35,6 +35,9 @@ from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn from paddlepalm.default_settings import * from task_instance import TaskInstance, check_instances +import Queue +from threading import Thread + DEBUG=False VERBOSE=0 @@ -399,11 +402,14 @@ class Controller(object): prefixes.append(inst.name) mrs.append(inst.mix_ratio) - joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE) + joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') + self._joint_iterator_fn = joint_iterator_fn input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] - net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) + # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) + net_inputs = create_net_inputs(input_attrs, async=False) + self._net_inputs = net_inputs # build backbone and task layers train_prog = fluid.default_main_program() @@ -568,6 +574,18 @@ class Controller(object): return False return True + def pack_multicard_feed(iterator, net_inputs, dev_count): + ret = [] + mask = [] + for i in range(dev_count): + temp = {} + content, flag = next(iterator) + for q, var in net_inputs.items(): + temp[var.name] = content[q] + ret.append(temp) + mask.append(1 if flag else 0) + return ret, mask + # do training fetch_names, fetch_list = zip(*fetches.items()) @@ -576,8 +594,50 @@ class Controller(object): epoch = 0 time_begin = time.time() backbone_buffer = [] + + def multi_dev_reader(reader, dev_count): + def worker(reader, dev_count, queue): + dev_batches = [] + for index, data in enumerate(reader()): + if len(dev_batches) < dev_count: + dev_batches.append(data) + if len(dev_batches) == dev_count: + queue.put((dev_batches, 0)) + dev_batches = [] + # For the prediction of the remained batches, pad more batches to + # the number of devices and the padded samples would be removed in + # prediction outputs. + if len(dev_batches) > 0: + num_pad = dev_count - len(dev_batches) + for i in range(len(dev_batches), dev_count): + dev_batches.append(dev_batches[-1]) + queue.put((dev_batches, num_pad)) + queue.put(None) + + queue = Queue.Queue(dev_count*2) + p = Thread( + target=worker, args=(reader, dev_count, queue)) + p.daemon = True + p.start() + while True: + ret = queue.get() + if ret is not None: + batches, num_pad = ret + queue.task_done() + for batch in batches: + flag = num_pad == 0 + if num_pad > 0: + num_pad -= 1 + yield batch, flag + else: + break + queue.join() + + joint_iterator = multi_dev_reader(self._joint_iterator_fn, self.dev_count) + while not train_finish(): - rt_outputs = self.exe.run(train_program, fetch_list=fetch_list) + feed, mask = pack_multicard_feed(joint_iterator, self._net_inputs, self.dev_count) + rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 544d4881d11d9acccbfd3a9aaa0538f6ff8c0cbc..efb047cfcd1497128deaa8ae60962fd4177fe0aa 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -105,11 +105,13 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p return iterator -def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0): +def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0, return_type='list'): """ joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查 """ + pos_to_outname = {j:i for i,j in outname_to_pos.items()} + task_ids = range(len(iterators)) weights = [mr / float(sum(mrs)) for mr in mrs] if not keep_one_task: @@ -202,7 +204,13 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype print(np.shape(i)) print('') v -= 1 - yield results + if return_type == 'list': + yield results + elif return_type == 'dict': + temp = {} + for pos, i in enumerate(results): + temp[pos_to_outname[pos]] = i + yield temp return iterator