diff --git a/README.md b/README.md
index e57aafa43f55b61b18a9ca059d88cf1cf42699eb..aaaf6e2bf46751929416a0ccd8f7d88d39cdb429 100644
--- a/README.md
+++ b/README.md
@@ -632,7 +632,9 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
 
 #### Text matching dataset reader: match
 
-This reader loads and processes text matching datasets. It accepts datasets in [tsv format](https://en.wikipedia.org/wiki/Tab-separated_values). A dataset should contain three columns: one column for the sample label `label`, and two columns for the texts to be matched, `text_a` and `text_b`. See the dataset files under `data/match4mrqa` for an example. The format looks like this:
+This reader loads and processes text matching datasets. It accepts datasets in [tsv format](https://en.wikipedia.org/wiki/Tab-separated_values). A dataset should contain three columns. Under the `pointwise` learning strategy, one column holds the sample label `label` and the other two hold the texts to be matched, `text_a` and `text_b`. Under the `pairwise` learning strategy, one column holds the sample `text_a` to be matched, and the other two hold its positive example `text_b` and negative example `text_b_neg`. The format looks like this:
+
+1. With the `pointwise` learning strategy:
 
 ```yaml
 label  text_a  text_b
@@ -642,10 +644,22 @@ label  text_a  text_b
 1   What has Pakistan told phone companies? **[TAB]** Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday.
 ```
 
+2. With the `pairwise` learning strategy:
+
+```yaml
+text_a  text_b  text_b_neg
+arrg ... ubuntunoob and ubuntu_user ... your nicks are confusing ^^ i d say it was **[TAB]** how that ... dynamic size of the container another idea would be an installation on an ( external ) flash-stick/card **[TAB]** will try now thanks if you have ati and md0 - i m no further help btw
+got an error message while installing __number__ need help ( initrmfs ) mount failure error do you see this grub no a little more info would help ;-) did you boot a cd or pen drive to install or install from windows was this a install from windows which is called a wubi how much memory does the computer have memory=ram so you got installed no errors and get this on reboot so when did you get this error did you burn it as a image **[TAB]** were you able to check the md5sum of the iso here is alink on md5sum i suspect it may not be this but never hurts to check __url__ **[TAB]** you would have to capture the pcl convert with hp 2xx then print that so do i set up another printer in cups with that driver but pointed to output to my cups pdf printer or do i need to pipe it through the driver on a lower level somehow
+okay i come from a windows background .. currently running v __number__ __number__ and having a video card ( ati ) issue ... if i have an issue like this ( in windows ) i would go to the vendor site locate a current driver and install in ubuntu it automatically downloaded a driver - this driver i assume does not come from the vendor site but rather a ubuntu repository of tested/approved drivers is that a correct assumption yes that is correct **[TAB]** so given the downloaded driver is not performing properly i went to ati and found they have a newer version driver what is the correct process to load the new version do i need to uninstall ( how ) the old version the new version is a run file - i am not familiar with what is the issue you re having with the ubuntu-supplied driver **[TAB]** ls -ld __path__ __path__ __path__ __path__ wrxr-xr-x
+hey he wanted excitement __url__ __url__ dapper multivers thank you so much now i can do apt-get build-dep mythtv and compile it myself np i cannot install those packages i am also needing them why ca n't you install them i just verified they re installable i am on a default dapper install with all extra repositories in sources list uncommented and cant then you do n't have the correct repo enabled **[TAB]** lame installed ( none ) apt-cache policy lame **[TAB]** i am using mercury ... i think it is better than amsn i lost the curiosity for this __number__ years ago but i ve back are you using a router
+```
+
 ***Note: the first row of the dataset must be a header that names each column.***
 
 The reader's outputs (the data yielded by the generator at each step) contain the following fields:
 
+1. With the `pointwise` learning strategy:
+
 ```yaml
 token_ids: a matrix of shape [batch_size, seq_len]. Each row is one sample (a text pair); each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
 position_ids: a matrix of shape [batch_size, seq_len]. Each row is one sample; each element is the position id of the corresponding token.
@@ -657,6 +671,22 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
 
 At prediction time, the data yielded by the reader does not contain the `label_ids` field.
 
+2. With the `pairwise` learning strategy:
+
+```yaml
+token_ids: a matrix of shape [batch_size, seq_len]. Each row is one positive sample (the text pair text_a, text_b); each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
+position_ids: a matrix of shape [batch_size, seq_len]. Each row is one sample; each element is the position id of the corresponding token.
+segment_ids: a matrix of shape [batch_size, seq_len]. Elements are 0 at token positions of text 1 (text_a) and 1 at token positions of text 2 (text_b). Used to support inputs for models such as BERT and ERNIE.
+input_mask: a matrix of shape [batch_size, seq_len]. Each element is 0 or 1, marking whether the position is padding (1 for a real token, 0 for a padding token).
+label_ids: a matrix of shape [batch_size]. Each element is the sample's class label: 0 means the two texts do not match, 1 means they match.
+task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
+token_ids_neg: a matrix of shape [batch_size, seq_len]. Each row is one negative sample (the text pair text_a, text_b_neg); each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
+position_ids_neg: a matrix of shape [batch_size, seq_len]. Each row is one negative sample; each element is the position id of the corresponding token.
+segment_ids_neg: a matrix of shape [batch_size, seq_len]. Elements are 0 at token positions of text 1 (text_a) and 1 at token positions of text 2 (text_b_neg). Used to support inputs for models such as BERT and ERNIE.
+input_mask_neg: a matrix of shape [batch_size, seq_len]. Each element is 0 or 1, marking whether the position is padding (1 for a real token, 0 for a padding token).
+task_ids_neg: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE model inputs.
+```
+
 
 #### Machine reading comprehension dataset reader: mrc
diff --git a/paddlepalm/__init__.py b/paddlepalm/__init__.py
index 5522f0faf96ce48efa034e522b4d793b740bcb23..f2f4e9fc2e3a72907c9e58bc88e49e8d645dc0ea 100644
--- a/paddlepalm/__init__.py
+++ b/paddlepalm/__init__.py
@@ -1,9 +1,10 @@
 import downloader
 from mtl_controller import Controller
-
+import distribute
+from distribute import gpu_dev_count, cpu_dev_count
 
 del interface
 del task_instance
 del default_settings
 del utils
-del mtl_controller
\ No newline at end of file
+del mtl_controller
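To make the new `pairwise` format concrete before moving on to the runtime changes, here is a minimal sketch of producing a compliant three-column tsv (the file name and the sample strings are illustrative, not taken from any shipped dataset):

```python
# Build a pairwise dataset: a tab-separated header row naming the columns,
# then one row per (query, positive response, negative response) triple.
rows = [
    ('text_a', 'text_b', 'text_b_neg'),
    ('how do i update my video driver',
     'use the driver from the ubuntu repository',
     'you can check the md5sum of the iso'),
]
with open('train.tsv', 'w') as f:
    for row in rows:
        f.write('\t'.join(row) + '\n')
```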
diff --git a/paddlepalm/distribute/__init__.py b/paddlepalm/distribute/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..255478be8053f52934e793000753fc41b41b0035
--- /dev/null
+++ b/paddlepalm/distribute/__init__.py
@@ -0,0 +1,9 @@
+from paddle import fluid
+import os
+import multiprocessing
+
+gpu_dev_count = int(fluid.core.get_cuda_device_count())
+cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+
+from reader import yield_pieces, data_feeder
+
diff --git a/paddlepalm/distribute/reader.py b/paddlepalm/distribute/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..95cfbf757ecfeb617ea0c95ad35b363b684ee70e
--- /dev/null
+++ b/paddlepalm/distribute/reader.py
@@ -0,0 +1,110 @@
+
+from . import gpu_dev_count, cpu_dev_count
+import Queue
+from threading import Thread
+
+dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
+
+def yield_pieces(data, distribute_strategy, batch_size):
+    """
+    Split a joint batch into per-device pieces.
+
+    Args:
+        distribute_strategy: one strategy per field; 's'=split, 'c'=copy, 'u'=unstack.
+    """
+    assert batch_size % dev_count == 0, "batch_size must be an integer multiple of dev_count."
+
+    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
+    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
+    if isinstance(data, dict):
+        keys = list(data.keys())
+        data_list = [data[i] for i in keys]
+        ds_list = [distribute_strategy[i] for i in keys]
+    else:
+        assert isinstance(data, list), "the input data must be a list or dict, containing multiple tensors."
+        data_list = data
+        ds_list = distribute_strategy
+
+    stride = batch_size // dev_count
+    p = stride
+    while p <= batch_size:
+        temp = []
+        for d, s in zip(data_list, ds_list):
+            s = s.strip().lower()
+            if s == 's' or s == 'split':
+                if p - stride >= len(d):
+                    print('WARNING: no more examples to feed empty devices')
+                    temp = []
+                    return
+                temp.append(d[p-stride:p])
+            elif s == 'u' or s == 'unstack':
+                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less than or equal to dev_count when unstack is applied.'
+                if p//stride > len(d):
+                    print('WARNING: no more examples to feed empty devices')
+                    return
+                temp.append(d[p//stride-1])
+            elif s == 'c' or s == 'copy':
+                temp.append(d)
+            else:
+                raise NotImplementedError()
+
+        p += stride
+        if type(data) == dict:
+            yield dict(zip(*[keys, temp]))
+        else:
+            yield temp
+
+def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
+
+    if postprocess_fn is None:
+        def postprocess_fn(batch, id=-1):
+            return batch
+
+    def worker(reader, dev_count, queue):
+        dev_batches = []
+        for index, data in enumerate(reader()):
+            if len(dev_batches) < dev_count:
+                dev_batches.append(data)
+            if len(dev_batches) == dev_count:
+                queue.put((dev_batches, 0))
+                dev_batches = []
+        # For the prediction of the remaining batches, pad more batches up to
+        # the number of devices; the padded samples will be removed from the
+        # prediction outputs.
+        if len(dev_batches) > 0:
+            num_pad = dev_count - len(dev_batches)
+            for i in range(len(dev_batches), dev_count):
+                dev_batches.append(dev_batches[-1])
+            queue.put((dev_batches, num_pad))
+        queue.put(None)
+
+    queue = Queue.Queue(dev_count*prefetch_steps)
+    p = Thread(
+        target=worker, args=(reader, dev_count, queue))
+    p.daemon = True
+    p.start()
+    while True:
+        ret = queue.get()
+        queue.task_done()
+        if ret is not None:
+            batches, num_pad = ret
+            id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
+            batch_buf = []
+            flag_buf = []
+            for idx, batch in enumerate(batches):
+                # flag marks whether this device batch is real (True) or padding (False)
+                flag = idx-len(batches) < -num_pad
+                batch = postprocess_fn(batch, id)
+                batch_buf.append(batch)
+                flag_buf.append(flag)
+            yield batch_buf, flag_buf, id
+        else:
+            break
+    queue.join()
+
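The three distribute strategies accepted by `yield_pieces` can be summarized with a self-contained toy (this mimics the intended behavior for a hypothetical `dev_count` of 2 without importing the module, so the real device count is not involved):

```python
# 's' (split):   slice the batch dimension evenly across devices.
# 'c' (copy):    hand every device the same value.
# 'u' (unstack): give device i the i-th element of a per-device list.
dev_count, batch_size = 2, 4
stride = batch_size // dev_count                    # 2 examples per device

batch = {
    'token_ids': [[1, 2], [3, 4], [5, 6], [7, 8]],  # 's'
    'vocab_size': 30522,                            # 'c'
    'mask_label': ['labels_dev0', 'labels_dev1'],   # 'u'
}

for i in range(dev_count):
    piece = {
        'token_ids': batch['token_ids'][i * stride:(i + 1) * stride],
        'vocab_size': batch['vocab_size'],
        'mask_label': batch['mask_label'][i],
    }
    print(piece)
```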
diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
index 5193d3423f13e5a819214e6e52e6b5e2033fd10d..ada105df5bab1dd2754f6187af9e3269bb873bc9 100755
--- a/paddlepalm/mtl_controller.py
+++ b/paddlepalm/mtl_controller.py
@@ -31,12 +31,11 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
 from paddlepalm.utils.config_helper import PDConfig
 from paddlepalm.utils.print_helper import print_dict
 from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
+from paddlepalm.distribute import data_feeder
 
-from paddlepalm.default_settings import *
+from default_settings import *
 from task_instance import TaskInstance, check_instances
 
-import Queue
-from threading import Thread
 
 DEBUG=False
 VERBOSE=0
@@ -185,6 +184,27 @@
     return conf
 
 
+def create_feed_batch_process_fn(net_inputs):
+
+    def feed_batch_process_fn(data, id=-1):
+        # Map reader output names to feed variable names. When `id` is given,
+        # net_inputs is a dict of per-task input dicts; otherwise it is flat.
+        temp = {}
+        inputs = net_inputs[id] if id != -1 else net_inputs
+
+        for q, var in inputs.items():
+            if isinstance(var, str) or isinstance(var, unicode):
+                temp[var] = data[q]
+            else:
+                temp[var.name] = data[q]
+
+        return temp
+
+    return feed_batch_process_fn
+
+
 class Controller(object):
 
     def __init__(self, config, task_dir='.', for_train=True):
@@ -330,6 +350,7 @@ class Controller(object):
         dev_count = self.dev_count
         num_instances = len(instances)
         mrs = self.mrs
+        branch = fluid.data(name="branch", shape=[1], dtype='int32')
 
         # set first_target/main task instance
         main_inst = None
@@ -349,35 +370,51 @@ class Controller(object):
 
         # create reader, task
         # then check i/o across reader, backbone and task_layer
-        task_attrs = []
+        task_attrs = {}
         pred_task_attrs = []
-        for inst in instances:
-            train_reader = inst.Reader(inst.config, phase='train')
-            inst.reader['train'] = train_reader
-            train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf)
-            inst.task_layer['train'] = train_parad
-            task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name)
-            task_attrs.append(task_attr_from_reader)
+        joint_input_names = {}
+        joint_shape_and_dtypes = {}
+        name_to_position = {}
+        for i in range(num_instances):
+            train_reader = instances[i].Reader(instances[i].config, phase='train')
+            instances[i].reader['train'] = train_reader
+            train_parad = instances[i].Paradigm(instances[i].config, phase='train', backbone_config=bb_conf)
+            instances[i].task_layer['train'] = train_parad
+            task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], instances[i].name)
+            task_attrs[i] = task_attr_from_reader
 
             _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train')
             _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train')
             _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone')
-
-            if inst.is_target:
-                if 'pred_file' not in inst.config:
-                    inst.config['pred_file'] = ''
-                pred_reader = inst.Reader(inst.config, phase='pred')
-                pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf)
-                inst.task_layer['pred'] = pred_parad
-                task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name)
+
+            if instances[i].is_target:
+                if 'pred_file' not in instances[i].config:
+                    instances[i].config['pred_file'] = ''
+                pred_reader = instances[i].Reader(instances[i].config, phase='pred')
+                pred_parad = instances[i].Paradigm(instances[i].config, phase='pred', backbone_config=bb_conf)
+                instances[i].task_layer['pred'] = pred_parad
+                task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], instances[i].name)
                 pred_task_attrs.append(task_attr_from_reader)
                 _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
                 _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
                 _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')
-
-        # merge reader input attrs from backbone and task_instances
-        joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs)
+
+            # merge reader input attrs from backbone and task_instances, per task
+            joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i] = merge_input_attrs(train_backbone.inputs_attr, task_attrs[i])
+        pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
+
         # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]
 
         if DEBUG:
@@ -387,10 +424,11 @@ class Controller(object):
             print('joint input shape and dtypes:')
             print(joint_shape_and_dtypes)
 
-        # load data
-        for inst in instances:
-            print(inst.name+": preparing data...", end='')
-            inst.reader['train'].load_data()
+        # load data
+        data_fns = {}
+        for i in range(num_instances):
+            print(instances[i].name+": preparing data...", end='')
+            instances[i].reader['train'].load_data()
             print('ok!')
 
         # merge dataset iterators and create net input vars
@@ -406,65 +444,65 @@ class Controller(object):
         joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
         self._joint_iterator_fn = joint_iterator_fn
 
-        input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)]
-        pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
-        # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3)
-        net_inputs = create_net_inputs(input_attrs, async=False)
-        self._net_inputs = net_inputs
-
-        # build backbone and task layers
-        train_prog = fluid.default_main_program()
-        train_init_prog = fluid.default_startup_program()
-        bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
-
-        assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())
+        input_attrs = {}
+        net_inputs = {}
+        bb_output_vars = {}
+        bb_output_fns = {}
 
+        # prepare predict vars for saving inference model
+        pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
         pred_prog = fluid.Program()
         pred_init_prog = fluid.Program()
+        self._pred_prog = pred_prog
 
         with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog):
-            pred_net_inputs = create_net_inputs(pred_input_attrs)
-            pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
-
-        fluid.framework.switch_main_program(train_prog)
-        fluid.framework.switch_startup_program(train_init_prog)
+            pred_net_inputs = create_net_inputs(pred_input_attrs)
+            pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
 
+        task_inputs = {}
         task_output_vars = {}
-        for inst in instances:
-            task_inputs = {'backbone': bb_output_vars}
-            task_inputs_from_reader = _decode_inputs(net_inputs, inst.name)
-            task_inputs['reader'] = task_inputs_from_reader
-
-            scope = inst.task_reuse_scope + '/'
+        task_fns = {}
+
+        def get_loss(i):
+            input_attrs[i] = [[m, j, k] for m, (j,k) in zip(joint_input_names[i], joint_shape_and_dtypes[i])]
+            net_inputs[i] = create_net_inputs(input_attrs[i], async=False)
+            # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3)
+            bb_output_vars[i] = train_backbone.build(net_inputs[i], scope_name='__paddlepalm_')
+            assert sorted(bb_output_vars[i].keys()) == sorted(train_backbone.outputs_attr.keys())
+
+            # build backbone and task layers
+            task_inputs[i] = {'backbone': bb_output_vars[i]}
+            task_inputs_from_reader = _decode_inputs(net_inputs[i], instances[i].name)
+            task_inputs[i]['reader'] = task_inputs_from_reader
+
+            scope = instances[i].task_reuse_scope + '/'
             with fluid.unique_name.guard(scope):
-
-                output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
-                output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()}
-                old = len(task_output_vars) # for debug
-                task_output_vars.update(output_vars)
-                assert len(task_output_vars) - old == len(output_vars) # for debug
-
-            # prepare predict vars for saving inference model
-            if inst.is_target:
+                output_vars = instances[i].build_task_layer(task_inputs[i], phase='train', scope=scope)
+                output_vars = {instances[i].name+'/'+key: val for key, val in output_vars.items()}
+                loss_var = output_vars[instances[i].name+'/loss']
+                task_output_vars[i] = output_vars
+
+            if instances[i].is_target:
                 with fluid.program_guard(pred_prog, pred_init_prog):
-                    cur_inputs = _decode_inputs(pred_net_inputs, inst.name)
-                    inst.pred_input = cur_inputs
+                    cur_inputs = _decode_inputs(pred_net_inputs, instances[i].name)
+                    instances[i].pred_input = cur_inputs
                     pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
-                    scope = inst.task_reuse_scope + '/'
+                    scope = instances[i].task_reuse_scope + '/'
                     with fluid.unique_name.guard(scope):
-                        inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope)
-
-
-        bb_fetches = {k: v.name for k,v in bb_output_vars.items()}
-        task_fetches = {k: v.name for k,v in task_output_vars.items()}
-        fetches = task_fetches
-        fetches['__task_id'] = net_inputs['__task_id'].name
-
-        # compute loss
-        task_id_var = net_inputs['__task_id']
-        task_id_vec = fluid.one_hot(task_id_var, num_instances)
-        losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0)
-        loss = layers.reduce_sum(task_id_vec * losses)
-
+                        instances[i].build_task_layer(pred_task_inputs, phase='pred', scope=scope)
+            return loss_var
+
+        for i in range(num_instances):
+            def task_loss():
+                task_id = i
+                return lambda: get_loss(task_id)
+            task_fns[i] = task_loss()
+
+        loss = layers.switch_case(
+            branch_index=branch,
+            branch_fns=task_fns
+        )
+        self._switched_loss = loss.name
 
         main_reader = main_inst.reader['train']
         num_examples = main_reader.num_examples
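The `task_loss` wrapper above is what makes the `switch_case` branches correct: it pins the loop variable at definition time, whereas binding `i` directly inside a lambda would leave every branch computing the loss of the last task. A standalone demonstration of the difference:

```python
# Late binding: all three lambdas share one closure cell and see the final i.
late = {i: (lambda: i) for i in range(3)}
print([late[k]() for k in range(3)])     # [2, 2, 2]

# Pinning the value in an enclosing scope, as task_loss does with task_id,
# freezes a distinct id into each branch function.
def make_fn(task_id):
    return lambda: task_id

pinned = {i: make_fn(i) for i in range(3)}
print([pinned[k]() for k in range(3)])   # [0, 1, 2]
```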
@@ -498,13 +536,14 @@
 
         # prepare for train
         self.train_backbone = train_backbone
-        self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
+        #self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
         self.saver_program = fluid.default_main_program()
+        self.train_program = self.saver_program
 
         self.main_inst = main_inst
-        self.fetches = fetches
         self.has_init_train = True
         self.has_init_pred = True
+        self._net_inputs = net_inputs
 
         self.exe.run(fluid.default_startup_program())
         print("\nRandomly initialize parameters...\n")
@@ -525,6 +564,7 @@ class Controller(object):
                 insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
 
             pred_prog = inst.load(infer_model_path)
+            # pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
             if inst.reader['pred'] is None:
                 pred_reader = inst.Reader(inst.config, phase='pred')
                 inst.reader['pred'] = pred_reader
@@ -556,8 +596,6 @@ class Controller(object):
         backbone = self.train_backbone
         train_program = self.train_program
         saver_program = self.saver_program
-        fetches = self.fetches
-
         finish = []
         for inst in instances:
             if inst.is_target:
@@ -575,90 +613,45 @@ class Controller(object):
                     return False
             return True
 
-        def pack_multicard_feed(iterator, net_inputs, dev_count):
-            ret = []
-            mask = []
-            for i in range(dev_count):
-                temp = {}
-                content, flag = next(iterator)
-                for q, var in net_inputs.items():
-                    temp[var.name] = content[q]
-                ret.append(temp)
-                mask.append(1 if flag else 0)
-            return ret, mask
 
         # do training
-        fetch_names, fetch_list = zip(*fetches.items())
-
+        fetch_names = {}
+        fetch_list = []
         main_step = 0 # only count for main task
         global_step = 0 # count for all tasks
         epoch = 0
         time_begin = time.time()
         backbone_buffer = []
 
-        def multi_dev_reader(reader, dev_count):
-            def worker(reader, dev_count, queue):
-                dev_batches = []
-                for index, data in enumerate(reader()):
-                    if len(dev_batches) < dev_count:
-                        dev_batches.append(data)
-                    if len(dev_batches) == dev_count:
-                        queue.put((dev_batches, 0))
-                        dev_batches = []
-                # For the prediction of the remained batches, pad more batches to
-                # the number of devices and the padded samples would be removed in
-                # prediction outputs.
-                if len(dev_batches) > 0:
-                    num_pad = dev_count - len(dev_batches)
-                    for i in range(len(dev_batches), dev_count):
-                        dev_batches.append(dev_batches[-1])
-                    queue.put((dev_batches, num_pad))
-                queue.put(None)
-
-            queue = Queue.Queue(dev_count*2)
-            p = Thread(
-                target=worker, args=(reader, dev_count, queue))
-            p.daemon = True
-            p.start()
-            while True:
-                ret = queue.get()
-                if ret is not None:
-                    batches, num_pad = ret
-                    queue.task_done()
-                    for batch in batches:
-                        flag = num_pad == 0
-                        if num_pad > 0:
-                            num_pad -= 1
-                        yield batch, flag
-                else:
-                    break
-            queue.join()
-
-        joint_iterator = multi_dev_reader(self._joint_iterator_fn, self.dev_count)
+        feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
+        distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)
+        # fetch the loss selected by switch_case; appended once, outside the loop
+        fetch_list.append(self._switched_loss)
 
         while not train_finish():
-            feed, mask = pack_multicard_feed(joint_iterator, self._net_inputs, self.dev_count)
+            feed, mask, id = next(distribute_feeder)
+
+            feed[0].update({'branch': np.array([id], dtype='int32')})
             rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
+            rt_loss = rt_outputs.pop()
+
             rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
-            rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist()
-            rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id
-            cur_task = instances[rt_task_id]
+            cur_task = instances[id]
 
-            backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k}
-            backbone_buffer.append(backbone.postprocess(backbone_rt_outputs))
+            # backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k}
+            # backbone_buffer.append(backbone.postprocess(backbone_rt_outputs))
 
-            task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')}
-            instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs)
+            # task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')}
+            # instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs)
 
             global_step += 1
             cur_task.cur_train_step += 1
 
             cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
             if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
-                cur_task.save(suffix='.step'+str(cur_task_global_step))
+                cur_task.save(suffix='.step'+str(cur_task_global_step), prog=self._pred_prog)
 
             if global_step % main_conf.get('print_every_n_steps', 5) == 0:
-                loss = rt_outputs[cur_task.name+'/loss']
+                loss = rt_loss
                 loss = np.mean(np.squeeze(loss)).tolist()
 
                 time_end = time.time()
@@ -671,7 +664,7 @@ class Controller(object):
 
             if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
                 print(cur_task.name+': train finished!')
-                cur_task.save()
+                cur_task.save(prog=self._pred_prog)
 
             if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0:
                 save_path = os.path.join(main_conf['save_path'], 'ckpt',
@@ -716,18 +709,38 @@ class Controller(object):
 
         fetch_names, fetch_vars = inst.pred_fetch_list
 
         print('predicting...')
-        mapper = {k:v for k,v in inst.pred_input}
+        feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
+        distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1, phase='pred')
+
         buf = []
-        for feed in inst.reader['pred'].iterator():
-            feed = _encode_inputs(feed, inst.name, cand_set=mapper)
-            feed = {mapper[k]: v for k,v in feed.items()}
+        for feed, mask, id in distribute_feeder:
             rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
+            # split each fetched tensor into per-device pieces, then drop the
+            # trailing pieces that correspond to padded device batches
+            splited_rt_outputs = []
+            for item in rt_outputs:
+                splited_rt_outputs.append(np.split(item, len(mask)))
+
+            while mask.pop() == False:
+                for item in splited_rt_outputs:
+                    item.pop()
+            rt_outputs = []
+            for item in splited_rt_outputs:
+                rt_outputs.append(np.concatenate(item))
+
             rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
             inst.postprocess(rt_outputs, phase='pred')
+
         if inst.task_layer['pred'].epoch_inputs_attrs:
             reader_outputs = inst.reader['pred'].get_epoch_outputs()
         else:
             reader_outputs = None
+
         inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
@@ -741,7 +754,4 @@ if __name__ == '__main__':
 
 
-__all__ = ["Controller"]
-
-
-
+__all__ = ["Controller"]
\ No newline at end of file
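The predict loop above removes padded device batches after the fact: each fetched tensor stacks all devices along dim 0, so it is split back into per-device pieces and the trailing pieces flagged as padding are dropped. In isolation (toy shapes; `np.split` requires dim 0 to be divisible by the number of devices):

```python
import numpy as np

# Two devices; the second device ran on a duplicated (padding) batch.
mask = [True, False]                  # flags from the data feeder
fetched = np.arange(8).reshape(4, 2)  # dim 0 stacks both devices' samples

pieces = np.split(fetched, len(mask))     # one piece per device
while mask and mask[-1] is False:         # drop padded tails
    mask.pop()
    pieces.pop()

merged = np.concatenate(pieces)
print(merged.shape)                       # (2, 2): padded half removed
```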
diff --git a/paddlepalm/reader/__init__.py b/paddlepalm/reader/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..43db376bb0f9aa151c5bfaa427d5de74c3287f2b 100644
--- a/paddlepalm/reader/__init__.py
+++ b/paddlepalm/reader/__init__.py
@@ -0,0 +1,8 @@
+
+from paddle import fluid
+import os
+import multiprocessing
+
+gpu_dev_count = int(fluid.core.get_cuda_device_count())
+cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py
index 376803b1a8fac2ccbae345a84da0df491600c5c8..96da1e67f0792771b9c5584d7821fb4b3c6016f9 100644
--- a/paddlepalm/reader/match.py
+++ b/paddlepalm/reader/match.py
@@ -16,6 +16,9 @@
 
 from paddlepalm.interface import reader
 from paddlepalm.reader.utils.reader4ernie import ClassifyReader
+
+
+
 class Reader(reader):
     
     def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
@@ -84,7 +87,6 @@ class Reader(reader):
                 "task_ids_neg": [[-1, -1], 'int64']
                 })
         return returns
-
     def load_data(self):
         self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
diff --git a/paddlepalm/reader/mlm.py b/paddlepalm/reader/mlm.py
index e4dff3477f3ffd56864bcbaed439f7bd03716377..4eb0cbf2c0b5b530c393cea41bf54fd9d7bb34ab 100644
--- a/paddlepalm/reader/mlm.py
+++ b/paddlepalm/reader/mlm.py
@@ -83,8 +83,6 @@ class Reader(reader):
             return outputs
 
         for batch in self._data_generator():
-            # print(np.shape(list_to_dict(batch)['token_ids']))
-            # print(list_to_dict(batch)['mask_label'].tolist())
             yield list_to_dict(batch)
 
     def get_epoch_outputs(self):
diff --git a/paddlepalm/reader/mrc.py b/paddlepalm/reader/mrc.py
index 2906b97ecb591fd6cc65f3a246c6d88e87dfccb8..4748184e4b3212a693459b6461a890f218828526 100644
--- a/paddlepalm/reader/mrc.py
+++ b/paddlepalm/reader/mrc.py
@@ -15,6 +15,7 @@
 
 from paddlepalm.interface import reader
 from paddlepalm.reader.utils.reader4ernie import MRCReader
+import numpy as np
 
 class Reader(reader):
diff --git a/paddlepalm/reader/utils/__init__.py b/paddlepalm/reader/utils/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5a662e41596646976dc3ada30575218cc055ef25 100644
--- a/paddlepalm/reader/utils/__init__.py
+++ b/paddlepalm/reader/utils/__init__.py
@@ -0,0 +1,9 @@
+
+
+from paddle import fluid
+import os
+import multiprocessing
+
+gpu_dev_count = int(fluid.core.get_cuda_device_count())
+cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
diff --git a/paddlepalm/reader/utils/mlm_batching.py b/paddlepalm/reader/utils/mlm_batching.py
index 8d6061d42a4ea096d17e4aeb6d3df394e852deb6..f824e44c9700682c2c90e4be440325f37bdfbb7a 100644
--- a/paddlepalm/reader/utils/mlm_batching.py
+++ b/paddlepalm/reader/utils/mlm_batching.py
@@ -19,57 +19,76 @@ from __future__ import print_function
 
 import numpy as np
 
-def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
+def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
     """
     Add mask for batch_tokens, return out, mask_label, mask_pos;
     Note: mask_pos responding the batch_tokens after padded;
     """
     max_len = max([len(sent) for sent in batch_tokens])
-    mask_label = []
-    mask_pos = []
-    prob_mask = np.random.rand(total_token_num)
-    # Note: the first token is [CLS], so [low=1]
-    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
-    pre_sent_len = 0
-    prob_index = 0
-    for sent_index, sent in enumerate(batch_tokens):
-        mask_flag = False
-        prob_index += pre_sent_len
-        for token_index, token in enumerate(sent):
-            prob = prob_mask[prob_index + token_index]
-            if prob > 0.15:
-                continue
-            elif 0.03 < prob <= 0.15:
-                # mask
-                if token != SEP and token != CLS:
+
+    multidev_batch_tokens = []
+    multidev_mask_label = []
+    multidev_mask_pos = []
+
+    big_batch_tokens = batch_tokens
+    stride = len(batch_tokens) // dev_count
+    if stride == 0:
+        return None, None, None
+    p = stride
+
+    for i in range(dev_count):
+        batch_tokens = big_batch_tokens[p-stride:p]
+        p += stride
+        mask_label = []
+        mask_pos = []
+        prob_mask = np.random.rand(total_token_num)
+        # Note: the first token is [CLS], so [low=1]
+        replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
+        pre_sent_len = 0
+        prob_index = 0
+        for sent_index, sent in enumerate(batch_tokens):
+            mask_flag = False
+            prob_index += pre_sent_len
+            for token_index, token in enumerate(sent):
+                prob = prob_mask[prob_index + token_index]
+                if prob > 0.15:
+                    continue
+                elif 0.03 < prob <= 0.15:
+                    # mask
+                    if token != SEP and token != CLS:
+                        mask_label.append(sent[token_index])
+                        sent[token_index] = MASK
+                        mask_flag = True
+                        mask_pos.append(sent_index * max_len + token_index)
+                elif 0.015 < prob <= 0.03:
+                    # random replace
+                    if token != SEP and token != CLS:
+                        mask_label.append(sent[token_index])
+                        sent[token_index] = replace_ids[prob_index + token_index]
+                        mask_flag = True
+                        mask_pos.append(sent_index * max_len + token_index)
+                else:
+                    # keep the original token
+                    if token != SEP and token != CLS:
+                        mask_label.append(sent[token_index])
+                        mask_pos.append(sent_index * max_len + token_index)
+            pre_sent_len = len(sent)
+            # ensure that at least one word in each sentence is masked
+            while not mask_flag:
+                token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
+                if sent[token_index] != SEP and sent[token_index] != CLS:
                     mask_label.append(sent[token_index])
                     sent[token_index] = MASK
                     mask_flag = True
                     mask_pos.append(sent_index * max_len + token_index)
-            elif 0.015 < prob <= 0.03:
-                # random replace
-                if token != SEP and token != CLS:
-                    mask_label.append(sent[token_index])
-                    sent[token_index] = replace_ids[prob_index + token_index]
-                    mask_flag = True
-                    mask_pos.append(sent_index * max_len + token_index)
-            else:
-                # keep the original token
-                if token != SEP and token != CLS:
-                    mask_label.append(sent[token_index])
-                    mask_pos.append(sent_index * max_len + token_index)
-        pre_sent_len = len(sent)
-        # ensure at least mask one word in a sentence
-        while not mask_flag:
-            token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
-            if sent[token_index] != SEP and sent[token_index] != CLS:
-                mask_label.append(sent[token_index])
-                sent[token_index] = MASK
-                mask_flag = True
-                mask_pos.append(sent_index * max_len + token_index)
-    mask_label = np.array(mask_label).astype("int64").reshape([-1])
-    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
-    return batch_tokens, mask_label, mask_pos
+        mask_label = np.array(mask_label).astype("int64").reshape([-1])
+        mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
+
+        multidev_batch_tokens.extend(batch_tokens)
+        multidev_mask_label.append(mask_label)
+        multidev_mask_pos.append(mask_pos)
+
+    return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
                        task_id=0,
                        return_input_mask=True,
                        return_max_len=True,
-                       return_num_token=False):
+                       return_num_token=False,
+                       dev_count=1):
     """
     1. generate Tensor of data
     2. generate Tensor of position
@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
         vocab_size=voc_size,
         CLS=cls_id,
         SEP=sep_id,
-        MASK=mask_id)
+        MASK=mask_id,
+        dev_count=dev_count)
     # Second step: padding
     src_id, self_input_mask = pad_batch_data(
         out,
@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
     return_list = [
         src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
     ]
-    return return_list if len(return_list) > 1 else return_list[0]
+    return return_list
 
 
 def pad_batch_data(insts,
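The refactored `mask` above shards one large batch into `dev_count` contiguous slices before masking, so each device's `mask_pos` is computed against its own shard. The slicing arithmetic, stripped of the masking logic:

```python
# Shard 6 sentences across 3 devices exactly as mask() walks its pointer p.
big_batch = ['s0', 's1', 's2', 's3', 's4', 's5']
dev_count = 3
stride = len(big_batch) // dev_count   # 2 per device; a stride of 0 aborts

p = stride
for i in range(dev_count):
    shard = big_batch[p - stride:p]    # contiguous slice for device i
    p += stride
    print('device %d: %s' % (i, shard))
```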
diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py
index d33f961fbdb1b1dc8faccea3e99678ceb6bcd963..115d2546be8337db67e1029587ec71db1918c5d9 100644
--- a/paddlepalm/reader/utils/reader4ernie.py
+++ b/paddlepalm/reader/utils/reader4ernie.py
@@ -29,11 +29,14 @@ import six
 from io import open
 from collections import namedtuple
 
+from . import gpu_dev_count
+import paddlepalm as palm
 import paddlepalm.tokenizer.ernie_tokenizer as tokenization
 from paddlepalm.reader.utils.batching4ernie import pad_batch_data
 from paddlepalm.reader.utils.mlm_batching import prepare_batch_data
 
+
 log = logging.getLogger(__name__)
 
 if six.PY3:
@@ -478,14 +481,12 @@ class MaskLMReader(BaseReader):
                     # max_len=self.max_seq_len, # Note: padding to the max length would make mask_pos mismatch the real token positions, because mask_pos is computed against the longest sequence within the batch.
                     return_input_mask=True,
                     return_max_len=False,
-                    return_num_token=False)
+                    return_num_token=False,
+                    dev_count=gpu_dev_count)
 
-                if len(all_dev_batches) < dev_count:
-                    all_dev_batches.append(batch_data)
-                if len(all_dev_batches) == dev_count:
-                    for batch in all_dev_batches:
-                        yield batch
-                    all_dev_batches = []
+                for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
+                    yield piece
 
         return wrapper
 
@@ -952,11 +953,20 @@ class MRCReader(BaseReader):
                 if to_append:
                     batch_records.append(record)
                 else:
-                    yield self._pad_batch_records(batch_records, phase == "train")
+                    ds = ['s'] * 8
+                    for piece in palm.distribute.yield_pieces(
+                            self._pad_batch_records(batch_records, phase == 'train'),
+                            ds, batch_size):
+                        yield piece
                     batch_records, max_len = [record], len(record.token_ids)
 
         if phase == 'pred' and batch_records:
-            yield self._pad_batch_records(batch_records, phase == "train")
+            # define ds here too: the else branch above may never have run
+            ds = ['s'] * 8
+            for piece in palm.distribute.yield_pieces(
+                    self._pad_batch_records(batch_records, phase == 'train'),
+                    ds, batch_size):
+                yield piece
 
     def _pad_batch_records(self, batch_records, is_training):
         batch_token_ids = [record.token_ids for record in batch_records]
@@ -1043,12 +1053,8 @@ class MRCReader(BaseReader):
 
             for batch_data in self._prepare_batch_data(
                     features, batch_size, phase=phase):
-                if len(all_dev_batches) < dev_count:
-                    all_dev_batches.append(batch_data)
-                if len(all_dev_batches) == dev_count:
-                    for batch in all_dev_batches:
-                        yield batch
-                    all_dev_batches = []
+
+                yield batch_data
 
         return wrapper
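The strategy string `['s', 's', 's', 's', 's', 'u', 'u']` that `MaskLMReader` now passes follows from the new return shape of `prepare_batch_data`: the first five entries (src, position, sentence, input mask, task ids) stack all devices' rows and can be split, while `mask_label` and `mask_pos` already come back from `mask()` as one array per device and only need to be unstacked. Schematically (the field contents below are placeholders):

```python
# Shape of the return_list for a hypothetical dev_count of 2.
batch_data = [
    'src_id (all rows)', 'pos_id (all rows)', 'sent_id (all rows)',
    'input_mask (all rows)', 'task_ids (all rows)',
    ['mask_label dev0', 'mask_label dev1'],   # per-device list -> 'u'
    ['mask_pos dev0', 'mask_pos dev1'],       # per-device list -> 'u'
]
strategies = ['s', 's', 's', 's', 's', 'u', 'u']
for field, s in zip(batch_data, strategies):
    print('%s -> %s' % (s, field))
```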
diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py
index 091526912cdaa3db8f475f10cc4bf2409d2bd31e..288f2ce893047f1e2251a8b341a069960647a27b 100644
--- a/paddlepalm/task_instance.py
+++ b/paddlepalm/task_instance.py
@@ -34,12 +34,13 @@ class TaskInstance(object):
         self._name = name
         self._config = config
         self._verbose = verbose
+        self._id = id
 
         check_req_args(config, name)
 
         # parse Reader and Paradigm
-        reader_name = config['reader']
-        reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
+        self.reader_name = config['reader']
+        reader_mod = importlib.import_module(READER_DIR + '.' + self.reader_name)
         Reader = getattr(reader_mod, 'Reader')
 
         parad_name = config['paradigm']
@@ -104,13 +105,18 @@ class TaskInstance(object):
     def epoch_postprocess(self, epoch_inputs, phase):
         return self._task_layer[phase].epoch_postprocess(epoch_inputs)
 
-    def save(self, suffix=''):
+    def save(self, suffix='', prog=None):
         dirpath = self._save_infermodel_path + suffix
         self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
 
         # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
-        prog = fluid.default_main_program().clone()
-        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
+        if prog is not None:
+            save_prog = prog
+        else:
+            save_prog = fluid.default_main_program().clone()
+
+        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, save_prog)
 
         conf = {}
         for k, strv in self._save_protocol.items():
@@ -137,6 +143,10 @@ class TaskInstance(object):
     def name(self):
         return self._name
 
+    @property
+    def tid(self):
+        return self._id
+
     @property
     def Reader(self):
         return self._Reader
@@ -169,7 +179,7 @@ class TaskInstance(object):
 
     @property
     def pred_input(self):
-        return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
+        return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))
 
     @pred_input.setter
     def pred_input(self, val):
diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task_paradigm/mrc.py
index ae36ecac82b94b6db89636645d226975fb94d5e6..001d0c7b2e537a61f00c8e05ec5459902ec1d835 100644
--- a/paddlepalm/task_paradigm/mrc.py
+++ b/paddlepalm/task_paradigm/mrc.py
@@ -85,6 +85,7 @@ class TaskParadigm(task_paradigm):
         else:
             unique_id = inputs['reader']['unique_ids']
 
+
         enc_out = inputs['backbone']['encoder_outputs']
         logits = fluid.layers.fc(
             input=enc_out,
diff --git a/paddlepalm/utils/__init__.py b/paddlepalm/utils/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..139597f9cb07c5d48bed18984ec4747f4b4f3438 100644
--- a/paddlepalm/utils/__init__.py
+++ b/paddlepalm/utils/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py
index c03e5fec04a62659f2252569c26ce27d0f01d63a..71585c303ea0dac8ba7fa64ba38dcf07e18a17d3 100644
--- a/paddlepalm/utils/reader_helper.py
+++ b/paddlepalm/utils/reader_helper.py
@@ -111,41 +111,39 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
         joint_shape_and_dtypes: essentially derived from the attrs of the backbone and paradigms, with the -1 (variable) dimensions auto-filled from the reader's attrs; validating it against the iterator therefore provides a runtime correctness check on each batch.
     """
 
-    pos_to_outname = {j:i for i,j in outname_to_pos.items()}
-
     task_ids = range(len(iterators))
     weights = [mr / float(sum(mrs)) for mr in mrs]
     if not keep_one_task:
         dev_count = 1
 
-    results = _zero_batch(joint_shape_and_dtypes)
-    outbuf = {}
+    results = {}
+    pos_to_outname = {}
     for id in task_ids:
+        pos_to_outname[id] = {j:i for i,j in outname_to_pos[id].items()}
+        result = _zero_batch(joint_shape_and_dtypes[id])
+        outbuf = {}
         outputs = next(iterators[id]) # dict type
         outbuf[id] = outputs
         prefix = iterator_prefixes[id]
         for outname, val in outputs.items():
             task_outname = prefix + '/' + outname
 
-            if outname in outname_to_pos:
-                idx = outname_to_pos[outname]
-                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
-                results[idx] = val
-
-            if task_outname in outname_to_pos:
-                idx = outname_to_pos[task_outname]
-                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
-                results[idx] = val
+            if outname in outname_to_pos[id]:
+                idx = outname_to_pos[id][outname]
+                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
+                result[idx] = val
 
-    fake_batch = results
-    dev_count_bak = dev_count
+            if task_outname in outname_to_pos[id]:
+                idx = outname_to_pos[id][task_outname]
+                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
+                result[idx] = val
+        results[id] = result
 
     def iterator():
         v = verbose
         has_show_warn = False
         while True:
             id = np.random.choice(task_ids, p=weights)
-            results = fake_batch
             if v > 0:
                 print('----- debug joint iterator -----')
                 print('sampled task id: '+str(id))
@@ -153,8 +151,8 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
 
             for i in range(dev_count):
-                results[outname_to_pos['__task_id']] = task_id_tensor
-                assert outname_to_pos['__task_id'] == 0
+                results[id][outname_to_pos[id]['__task_id']] = task_id_tensor
+                assert outname_to_pos[id]['__task_id'] == 0
 
                 if id in outbuf:
                     outputs = outbuf[id]
@@ -165,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
                 if 'token_ids' in outputs:
                     val1 = len(outputs['token_ids'])
                     val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
-                    results[outname_to_pos['batch_size']] = val
+                    results[id][outname_to_pos[id]['batch_size']] = val
 
                     val2 = len(outputs['token_ids'][0])
                     val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
-                    results[outname_to_pos['seqlen']] = val
+                    results[id][outname_to_pos[id]['seqlen']] = val
 
                     val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
-                    results[outname_to_pos['batchsize_x_seqlen']] = val
+                    results[id][outname_to_pos[id]['batchsize_x_seqlen']] = val
                 else:
                     if not has_show_warn:
                         print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
@@ -184,33 +182,33 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
                         print('reader generate: '+outname)
                     task_outname = prefix + '/' + outname
 
-                    if outname in outname_to_pos:
-                        idx = outname_to_pos[outname]
+                    if outname in outname_to_pos[id]:
+                        idx = outname_to_pos[id][outname]
                         if v > 0:
                             print(outname + ' is insert in idx ' + str(idx))
-                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
-                        results[idx] = val
+                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
+                        results[id][idx] = val
 
-                    if task_outname in outname_to_pos:
-                        idx = outname_to_pos[task_outname]
+                    if task_outname in outname_to_pos[id]:
+                        idx = outname_to_pos[id][task_outname]
                         if v > 0:
                             print(task_outname + ' is insert in idx ' + str(idx))
-                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
-                        results[idx] = val
+                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
+                        results[id][idx] = val
 
                 if v > 0:
                     print('yielded batch len and shapes:')
-                    print(len(results))
-                    for i in results:
+                    print(len(results[id]))
+                    for i in results[id]:
                         print(np.shape(i))
                     print('')
                     v -= 1
                 if return_type == 'list':
-                    yield results
+                    yield results[id]
                 elif return_type == 'dict':
                     temp = {}
-                    for pos, i in enumerate(results):
-                        temp[pos_to_outname[pos]] = i
+                    for pos, i in enumerate(results[id]):
+                        temp[pos_to_outname[id][pos]] = i
                     yield temp
 
     return iterator
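Taken together, the new feeding path can be exercised in isolation with a stub reader. A minimal sketch, assuming the patch's Python 2 environment with PaddlePaddle installed; the stub reader and the pass-through postprocess function are hypothetical, not part of the patch:

```python
import numpy as np
from paddlepalm.distribute import data_feeder

# A stub joint iterator yielding dict batches shaped like the real one:
# '__task_id' is what data_feeder reads to route the batch in train phase.
def stub_reader():
    for step in range(4):
        yield {'__task_id': np.array([[step % 2]]),
               'token_ids': np.zeros((2, 5), dtype='int64')}

def passthrough(batch, task_id):
    # The real callback (create_feed_batch_process_fn) renames reader keys
    # to feed-variable names here; the stub just forwards the batch.
    return batch

for feed, flags, task_id in data_feeder(stub_reader, passthrough):
    # feed holds one dict per device; flags marks padded device batches.
    print('task %d: %d device batches, flags=%s' % (task_id, len(feed), flags))
```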