提交 ba55793f 编写于 作者: W wangxiao1021

change to switch op based

上级 de37fd75
......@@ -632,7 +632,9 @@ task_ids: 一个shape为[batch_size, seq_len]的全0矩阵,用于支持ERNIE
#### 文本匹配数据集reader工具:match
该reader完成文本匹配数据集的载入与处理,reader接受[tsv格式](https://en.wikipedia.org/wiki/Tab-separated_values)的数据集输入,数据集应该包含三列,一列为样本标签`label`,其余两列分别为待匹配的文本`text_a`和文本`text_b`。数据集范例可参考`data/match4mrqa`中的数据集文件,格式形如
该reader完成文本匹配数据集的载入与处理,reader接受[tsv格式](https://en.wikipedia.org/wiki/Tab-separated_values)的数据集输入,数据集应该包含三列,对于`pointwise`的学习策略,其中,一列为样本标签`label`,其余两列分别为待匹配的文本`text_a`和文本`text_b`;对于`pairwise`的学习策略,其中,一列为待匹配的样本`text_a`,其余为其对应的正例`text_b`和负例`text_b_neg`。格式形如
1. 学习策略为`pointwise`:
```yaml
label text_a text_b
......@@ -642,10 +644,22 @@ label text_a text_b
1 What has Pakistan told phone companies? **[TAB]** Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday.
```
2. 学习策略为`pairwise`:
```yaml
text_a text_b text_b_neg
arrg ... ubuntunoob and ubuntu_user ... your nicks are confusing ^^ i d say it was **[TAB]** how that ... dynamic size of the c ontainer another idea would be an installation on an ( external ) flash-stick/card **[TAB]** will try now thanks if you have ati and md0 - i m no further help btw
got an error message while installing __number__ need help ( initrmfs ) mount failure error do you see this grub no a little more info would help ;-) did you boot a cd or pen drive to install or install from windows was this a install from windows whi ch is called a wubi how much memory does the computer have memory=ram so you got installed no errors and get this on reboot so when did you get this error did you burn it as a image **[TAB]** were you able to check the md5sum of the iso here is alink on md5sum i suspect it may not be this but never hurts to check __url__ **[TAB]** you would have to capture the pcl convert with hp 2xx then print that so do i set up another printer in cups with that driver but pointed to output to my cups pdf printer or do i need to pipe it through the driver on a lower level somehow
okay i come from a windows background .. currently running v __number__ __number__ and having a video card ( ati ) issue ... if i have an issue like this ( in windows ) i would go to the vendor site locate a current driver and install in ubuntu it aut omatically downloaded a driver - this driver i assume does not come from the vendor site but rather a ubuntu repository of tes ted/approved drivers is that a correct assumption yes that is correct **[TAB]** so given the downloaded driver is not performing properly i went to ati and found they have a newer version driver what is the correct process to load the new version do i ne ed to uninstall ( how ) the old version the new version is a run file - i am not familiar with what is the issue you re having with the ubuntu-supplied driver **[TAB]** ls -ld __path__ __path__ __path__ __path__ wrxr-xr-x
hey he wanted excitement __url__ __url__ dapper multivers thank you so much now i can do apt-get build-dep mythtv and compile it myself np i cannot install those packages i am also needing them why ca n't you install them i just verified they re insta llable i am on a default dapper install with all extra repositories in sources list uncommented and cant then you do n't have the correct repo enabled **[TAB]** lame installed ( none ) apt-cache policy lame **[TAB]** i am using mercury ... i think it is be tter than amsn i lost the curiosity for this __number__ years ago but i ve back are you using a router
```
***注意:数据集的第一列必须为header,即标注每一列的列名***
reader的输出(生成器每次yield出的数据)包含以下字段:
1. 学习策略为`pointwise`:
```yaml
token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本(文本对),其中的每个元素为文本对中的每个token对应的单词id,文本对使用`[SEP]`所对应的id隔开。
position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。
......@@ -657,6 +671,22 @@ task_ids: 一个shape为[batch_size, seq_len]的全0矩阵,用于支持ERNIE
当处于预测阶段时,reader所yield出的数据不会包含`label_ids`字段。
2. 学习策略为`pairwise`:
```yaml
token_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条正样本(文本对text_a text_b),其中的每个元素为文本对中的每个token对应的单词id,文本对使用`[SEP]`所对应的id隔开。
position_ids: 一个shape为[batch_size, seq_len]的矩阵,每行是一条样本,其中的每个元素为文本中的每个token对应的位置id。
segment_ids: 一个shape为[batch_size, seq_len]的矩阵,在文本1(text_a)的token位置,元素取值为0;在文本2(text_b)的token位置,元素取值为1。用于支持BERT、ERNIE等模型的输入。
input_mask: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。
label_ids: 一个shape为[batch_size]的矩阵,其中的每个元素为该样本的类别标签,为0时表示两段文本不匹配,为1时代表构成匹配。
task_ids: 一个shape为[batch_size, seq_len]的全0矩阵,用于支持ERNIE模型的输入。
token_ids_neg: 一个shape为[batch_size, seq_len]的矩阵,每行是一条负样本(文本对text_a text_b_neg),其中的每个元素为文本对中的每个token对应的单词id,文本对使用`[SEP]`所对应的id隔开。
position_ids_neg: 一个shape为[batch_size, seq_len]的矩阵,每行是一条负样本,其中的每个元素为文本中的每个token对应的位置id。
segment_ids_neg: 一个shape为[batch_size, seq_len]的矩阵,在文本1(text_a)的token位置,元素取值为0;在文本2(text_b_neg)的token位置,元素取值为1。用于支持BERT、ERNIE等模型的输入。
input_mask_neg: 一个shape为[batch_size, seq_len]的矩阵,其中的每个元素为0或1,表示该位置是否是padding词(为1时代表是真实词,为0时代表是填充词)。
task_ids_neg: 一个shape为[batch_size, seq_len]的全0矩阵,用于支持ERNIE模型的输入。
```
#### 机器阅读理解数据集reader工具:mrc
......
import downloader
from mtl_controller import Controller
import distribute
from distribute import gpu_dev_count, cpu_dev_count
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
del mtl_controller
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder
from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
def yield_pieces(data, distribute_strategy, batch_size):
"""
Args:
distribute_strategy: support s=split, c=copy, u=unstack,
"""
assert batch_size % dev_count == 0, "batch_size need to be integer times larger than dev_count."
# print('data in yield pieces')
# print(len(data))
assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
if isinstance(data, dict):
keys = list(data.keys())
data_list = [data[i] for i in keys]
ds_list = [distribute_strategy[i] for i in keys]
else:
assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors."
data_list = data
ds_list = distribute_strategy
stride = batch_size // dev_count
p = stride
# while p < len(data_list) + stride:
while p <= batch_size:
temp = []
for d, s in zip(data_list, ds_list):
s = s.strip().lower()
if s == 's' or s == 'split':
if p - stride >= len(d):
print('WARNING: no more examples to feed empty devices')
temp = []
return
temp.append(d[p-stride:p])
elif s == 'u' or s == 'unstack':
assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.'
if p//stride > len(d):
print('WARNING: no more examples to feed empty devices')
return
temp.append(d[p//stride-1])
elif s == 'c' or s == 'copy':
temp.append(d)
else:
raise NotImplementedError()
p += stride
if type(data) == dict:
yield dict(zip(*[keys, temp]))
else:
# print('yielded pieces')
# print(len(temp))
yield temp
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
if postprocess_fn is None:
def postprocess_fn(batch):
return batch
def worker(reader, dev_count, queue):
dev_batches = []
for index, data in enumerate(reader()):
if len(dev_batches) < dev_count:
dev_batches.append(data)
if len(dev_batches) == dev_count:
queue.put((dev_batches, 0))
dev_batches = []
# For the prediction of the remained batches, pad more batches to
# the number of devices and the padded samples would be removed in
# prediction outputs.
if len(dev_batches) > 0:
num_pad = dev_count - len(dev_batches)
for i in range(len(dev_batches), dev_count):
dev_batches.append(dev_batches[-1])
queue.put((dev_batches, num_pad))
queue.put(None)
queue = Queue.Queue(dev_count*prefetch_steps)
p = Thread(
target=worker, args=(reader, dev_count, queue))
p.daemon = True
p.start()
while True:
ret = queue.get()
queue.task_done()
if ret is not None:
batches, num_pad = ret
id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
batch_buf = []
flag_buf = []
for idx, batch in enumerate(batches):
# flag = num_pad == 0
flag = idx-len(batches) < -num_pad
# if num_pad > 0:
# num_pad -= 1
batch = postprocess_fn(batch, id)
batch_buf.append(batch)
flag_buf.append(flag)
yield batch_buf, flag_buf, id
else:
break
queue.join()
此差异已折叠。
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
......@@ -16,6 +16,9 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
......@@ -84,7 +87,6 @@ class Reader(reader):
"task_ids_neg": [[-1, -1], 'int64']
})
return returns
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
......
......@@ -83,8 +83,6 @@ class Reader(reader):
return outputs
for batch in self._data_generator():
# print(np.shape(list_to_dict(batch)['token_ids']))
# print(list_to_dict(batch)['mask_label'].tolist())
yield list_to_dict(batch)
def get_epoch_outputs(self):
......
......@@ -15,6 +15,7 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
import numpy as np
class Reader(reader):
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
......@@ -19,57 +19,76 @@ from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
multidev_batch_tokens = []
multidev_mask_label = []
multidev_mask_pos = []
big_batch_tokens = batch_tokens
stride = len(batch_tokens) // dev_count
if stride == 0:
return None, None, None
p = stride
for i in range(dev_count):
batch_tokens = big_batch_tokens[p-stride:p]
p += stride
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
return batch_tokens, mask_label, mask_pos
mask_label = np.array(mask_label).astype("int64").reshape([-1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
multidev_batch_tokens.extend(batch_tokens)
multidev_mask_label.append(mask_label)
multidev_mask_pos.append(mask_pos)
return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
def prepare_batch_data(insts,
......@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
task_id=0,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
return_num_token=False,
dev_count=1):
"""
1. generate Tensor of data
2. generate Tensor of position
......@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
MASK=mask_id,
dev_count=dev_count)
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
......@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
return_list = [
src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
]
return return_list if len(return_list) > 1 else return_list[0]
return return_list
def pad_batch_data(insts,
......
......@@ -29,11 +29,14 @@ import six
from io import open
from collections import namedtuple
from . import gpu_dev_count
import paddlepalm as palm
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
from paddlepalm.reader.utils.batching4ernie import pad_batch_data
from paddlepalm.reader.utils.mlm_batching import prepare_batch_data
log = logging.getLogger(__name__)
if six.PY3:
......@@ -478,14 +481,12 @@ class MaskLMReader(BaseReader):
# max_len=self.max_seq_len, # 注意,如果padding到最大长度,会导致mask_pos与实际位置不对应。因为mask pos是基于batch内最大长度来计算的。
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return_num_token=False,
dev_count=gpu_dev_count)
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
# yield batch
for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
yield piece
return wrapper
......@@ -952,11 +953,20 @@ class MRCReader(BaseReader):
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records, phase == "train")
# yield self._pad_batch_records(batch_records, phase == "train")
ds = ['s'] * 8
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records, phase == 'train'),
ds, batch_size):
yield piece
batch_records, max_len = [record], len(record.token_ids)
if phase == 'pred' and batch_records:
yield self._pad_batch_records(batch_records, phase == "train")
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records, phase == 'train'),
ds, batch_size):
yield piece
def _pad_batch_records(self, batch_records, is_training):
batch_token_ids = [record.token_ids for record in batch_records]
......@@ -1043,12 +1053,8 @@ class MRCReader(BaseReader):
for batch_data in self._prepare_batch_data(
features, batch_size, phase=phase):
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
yield batch_data
return wrapper
......
......@@ -34,12 +34,13 @@ class TaskInstance(object):
self._name = name
self._config = config
self._verbose = verbose
self._id = id
check_req_args(config, name)
# parse Reader and Paradigm
reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
self.reader_name = config['reader']
reader_mod = importlib.import_module(READER_DIR + '.' + self.reader_name)
Reader = getattr(reader_mod, 'Reader')
parad_name = config['paradigm']
......@@ -104,13 +105,18 @@ class TaskInstance(object):
def epoch_postprocess(self, epoch_inputs, phase):
return self._task_layer[phase].epoch_postprocess(epoch_inputs)
def save(self, suffix=''):
def save(self, suffix='', prog=None):
dirpath = self._save_infermodel_path + suffix
self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
# fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True)
prog = fluid.default_main_program().clone()
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
# prog = fluid.default_main_program().clone()
if prog is not None:
save_prog = prog
else:
save_prog = fluid.default_main_program().clone()
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, save_prog)
conf = {}
for k, strv in self._save_protocol.items():
......@@ -137,6 +143,10 @@ class TaskInstance(object):
def name(self):
return self._name
@property
def tid(self):
return self._id
@property
def Reader(self):
return self._Reader
......@@ -169,7 +179,7 @@ class TaskInstance(object):
@property
def pred_input(self):
return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))
@pred_input.setter
def pred_input(self, val):
......
......@@ -85,6 +85,7 @@ class TaskParadigm(task_paradigm):
else:
unique_id = inputs['reader']['unique_ids']
enc_out = inputs['backbone']['encoder_outputs']
logits = fluid.layers.fc(
input=enc_out,
......
......@@ -111,41 +111,39 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查
"""
pos_to_outname = {j:i for i,j in outname_to_pos.items()}
task_ids = range(len(iterators))
weights = [mr / float(sum(mrs)) for mr in mrs]
if not keep_one_task:
dev_count = 1
results = _zero_batch(joint_shape_and_dtypes)
outbuf = {}
results = {}
pos_to_outname = {}
for id in task_ids:
pos_to_outname[id] = {j:i for i,j in outname_to_pos[id].items()}
result = _zero_batch(joint_shape_and_dtypes[id])
outbuf = {}
outputs = next(iterators[id]) # dict type
outbuf[id] = outputs
prefix = iterator_prefixes[id]
for outname, val in outputs.items():
task_outname = prefix + '/' + outname
if outname in outname_to_pos:
idx = outname_to_pos[outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
results[idx] = val
if task_outname in outname_to_pos:
idx = outname_to_pos[task_outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
results[idx] = val
if outname in outname_to_pos[id]:
idx = outname_to_pos[id][outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
result[idx] = val
fake_batch = results
dev_count_bak = dev_count
if task_outname in outname_to_pos[id]:
idx = outname_to_pos[id][task_outname]
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
result[idx] = val
results[id] = result
def iterator():
v = verbose
has_show_warn = False
while True:
id = np.random.choice(task_ids, p=weights)
results = fake_batch
if v > 0:
print('----- debug joint iterator -----')
print('sampled task id: '+str(id))
......@@ -153,8 +151,8 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
for i in range(dev_count):
results[outname_to_pos['__task_id']] = task_id_tensor
assert outname_to_pos['__task_id'] == 0
results[id][outname_to_pos[id]['__task_id']] = task_id_tensor
assert outname_to_pos[id]['__task_id'] == 0
if id in outbuf:
outputs = outbuf[id]
......@@ -165,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
if 'token_ids' in outputs:
val1 = len(outputs['token_ids'])
val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
results[outname_to_pos['batch_size']] = val
results[id][outname_to_pos[id]['batch_size']] = val
val2 = len(outputs['token_ids'][0])
val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['seqlen']] = val
results[id][outname_to_pos[id]['seqlen']] = val
val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
results[outname_to_pos['batchsize_x_seqlen']] = val
results[id][outname_to_pos[id]['batchsize_x_seqlen']] = val
else:
if not has_show_warn:
print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
......@@ -184,33 +182,33 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
print('reader generate: '+outname)
task_outname = prefix + '/' + outname
if outname in outname_to_pos:
idx = outname_to_pos[outname]
if outname in outname_to_pos[id]:
idx = outname_to_pos[id][outname]
if v > 0:
print(outname + ' is insert in idx ' + str(idx))
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
results[idx] = val
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
results[id][idx] = val
if task_outname in outname_to_pos:
idx = outname_to_pos[task_outname]
if task_outname in outname_to_pos[id]:
idx = outname_to_pos[id][task_outname]
if v > 0:
print(task_outname + ' is insert in idx ' + str(idx))
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
results[idx] = val
val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
results[id][idx] = val
if v > 0:
print('yielded batch len and shapes:')
print(len(results))
for i in results:
print(len(results[id]))
for i in results[id]:
print(np.shape(i))
print('')
v -= 1
if return_type == 'list':
yield results
yield results[id]
elif return_type == 'dict':
temp = {}
for pos, i in enumerate(results):
temp[pos_to_outname[pos]] = i
for pos, i in enumerate(results[id]):
temp[pos_to_outname[id][pos]] = i
yield temp
return iterator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册