Commit ba55793f authored by wangxiao1021

change to switch op based

Parent de37fd75
...@@ -632,7 +632,9 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE
#### Text matching dataset reader: match
This reader loads and processes text matching datasets. It accepts dataset input in [tsv format](https://en.wikipedia.org/wiki/Tab-separated_values), and the dataset should contain three columns. Under the `pointwise` learning strategy, one column is the sample label `label`, and the other two columns are the texts to be matched, `text_a` and `text_b`. Under the `pairwise` learning strategy, one column is the sample to be matched, `text_a`, and the other two are its positive example `text_b` and negative example `text_b_neg`. The format looks like:
1. With the `pointwise` learning strategy:
```yaml
label text_a text_b
...@@ -642,10 +644,22 @@ label text_a text_b
1 What has Pakistan told phone companies? **[TAB]** Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday.
```
2. With the `pairwise` learning strategy:
```yaml
text_a text_b text_b_neg
arrg ... ubuntunoob and ubuntu_user ... your nicks are confusing ^^ i d say it was **[TAB]** how that ... dynamic size of the container another idea would be an installation on an ( external ) flash-stick/card **[TAB]** will try now thanks if you have ati and md0 - i m no further help btw
got an error message while installing __number__ need help ( initrmfs ) mount failure error do you see this grub no a little more info would help ;-) did you boot a cd or pen drive to install or install from windows was this a install from windows which is called a wubi how much memory does the computer have memory=ram so you got installed no errors and get this on reboot so when did you get this error did you burn it as a image **[TAB]** were you able to check the md5sum of the iso here is alink on md5sum i suspect it may not be this but never hurts to check __url__ **[TAB]** you would have to capture the pcl convert with hp 2xx then print that so do i set up another printer in cups with that driver but pointed to output to my cups pdf printer or do i need to pipe it through the driver on a lower level somehow
okay i come from a windows background .. currently running v __number__ __number__ and having a video card ( ati ) issue ... if i have an issue like this ( in windows ) i would go to the vendor site locate a current driver and install in ubuntu it automatically downloaded a driver - this driver i assume does not come from the vendor site but rather a ubuntu repository of tested/approved drivers is that a correct assumption yes that is correct **[TAB]** so given the downloaded driver is not performing properly i went to ati and found they have a newer version driver what is the correct process to load the new version do i need to uninstall ( how ) the old version the new version is a run file - i am not familiar with what is the issue you re having with the ubuntu-supplied driver **[TAB]** ls -ld __path__ __path__ __path__ __path__ wrxr-xr-x
hey he wanted excitement __url__ __url__ dapper multivers thank you so much now i can do apt-get build-dep mythtv and compile it myself np i cannot install those packages i am also needing them why ca n't you install them i just verified they re installable i am on a default dapper install with all extra repositories in sources list uncommented and cant then you do n't have the correct repo enabled **[TAB]** lame installed ( none ) apt-cache policy lame **[TAB]** i am using mercury ... i think it is better than amsn i lost the curiosity for this __number__ years ago but i ve back are you using a router
```
***Note: the first line of the dataset must be a header naming each column.***
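For illustration, a minimal sketch of loading such a tsv file (the file path is hypothetical, and the real reader additionally tokenizes the texts and maps them to ids):

```python
import csv

# Hypothetical pairwise-style file with the required header line.
with open('data/match/train.tsv', encoding='utf-8') as f:
    for row in csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE):
        # For pointwise data, the keys would be label/text_a/text_b instead.
        text_a = row['text_a']
        text_b = row['text_b']
        text_b_neg = row['text_b_neg']
```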
The reader's output (the data yielded by the generator on each call) contains the following fields; a shape-only sketch of a pairwise batch follows the two lists below:
1. With the `pointwise` learning strategy:
```yaml
token_ids: a matrix of shape [batch_size, seq_len]; each row is one sample (a text pair), where each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, where each element is the position id of the corresponding token.
...@@ -657,6 +671,22 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support ERNIE
At prediction time, the data yielded by the reader does not contain the `label_ids` field.
2. With the `pairwise` learning strategy:
```yaml
token_ids: a matrix of shape [batch_size, seq_len]; each row is one positive sample (the text pair text_a, text_b), where each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, where each element is the position id of the corresponding token.
segment_ids: a matrix of shape [batch_size, seq_len]; elements are 0 at token positions of text 1 (text_a) and 1 at token positions of text 2 (text_b). Used to support the inputs of models such as BERT and ERNIE.
input_mask: a matrix of shape [batch_size, seq_len]; each element is 0 or 1, marking whether the position is a padding token (1 for a real token, 0 for padding).
label_ids: a vector of shape [batch_size]; each element is the class label of the sample, 0 meaning the two texts do not match and 1 meaning they match.
task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support the ERNIE model's input.
token_ids_neg: a matrix of shape [batch_size, seq_len]; each row is one negative sample (the text pair text_a, text_b_neg), where each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids_neg: a matrix of shape [batch_size, seq_len]; each row is one negative sample, where each element is the position id of the corresponding token.
segment_ids_neg: a matrix of shape [batch_size, seq_len]; elements are 0 at token positions of text 1 (text_a) and 1 at token positions of text 2 (text_b_neg). Used to support the inputs of models such as BERT and ERNIE.
input_mask_neg: a matrix of shape [batch_size, seq_len]; each element is 0 or 1, marking whether the position is a padding token (1 for a real token, 0 for padding).
task_ids_neg: an all-zero matrix of shape [batch_size, seq_len], used to support the ERNIE model's input.
```
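As a sanity check, here is a shape-only sketch of what one yielded `pairwise` training batch looks like (`batch_size`, `seq_len`, and the dtypes are illustrative assumptions):

```python
import numpy as np

batch_size, seq_len = 4, 128  # illustrative sizes
batch = {name: np.zeros([batch_size, seq_len], dtype='int64')
         for name in ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
                      'task_ids', 'token_ids_neg', 'position_ids_neg',
                      'segment_ids_neg', 'input_mask_neg', 'task_ids_neg']}
batch['label_ids'] = np.ones([batch_size], dtype='int64')  # 1 = matched
```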
#### Machine reading comprehension dataset reader: mrc
......
import downloader
from mtl_controller import Controller
import distribute
from distribute import gpu_dev_count, cpu_dev_count
del interface
del task_instance
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder
from . import gpu_dev_count, cpu_dev_count
try:
    import queue as Queue  # Python 3
except ImportError:
    import Queue  # Python 2
from threading import Thread

dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
def yield_pieces(data, distribute_strategy, batch_size):
    """Split a joint batch into per-device pieces.

    Args:
        data: a list or dict of batch tensors.
        distribute_strategy: same structure as data; each element is one of
            s=split (slice the batch evenly across devices),
            c=copy (give every device the full tensor),
            u=unstack (give each device one element along dim 0).
        batch_size: global batch size; must be divisible by dev_count.
    """
    assert batch_size % dev_count == 0, "batch_size must be an integer multiple of dev_count."
    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]

    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[i] for i in keys]
        ds_list = [distribute_strategy[i] for i in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or dict containing multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    stride = batch_size // dev_count
    p = stride
    while p <= batch_size:
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p - stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p // stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()
        p += stride

        if isinstance(data, dict):
            yield dict(zip(keys, temp))
        else:
            yield temp
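# Usage sketch for yield_pieces (hypothetical data; assumes this module
# resolved dev_count == 2). 's' slices the batch evenly across devices,
# 'c' copies a value to every device, 'u' hands out one element per device:
#
#   data = {'token_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'vocab_size': 30000}
#   strategy = {'token_ids': 's', 'vocab_size': 'c'}
#   pieces = list(yield_pieces(data, strategy, batch_size=4))
#   # -> rows [0:2] on device 0, rows [2:4] on device 1, vocab_size copied to both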
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
    if postprocess_fn is None:
        def postprocess_fn(batch, id=-1):
            return batch

    def worker(reader, dev_count, queue):
        dev_batches = []
        for index, data in enumerate(reader()):
            if len(dev_batches) < dev_count:
                dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For the prediction of the remained batches, pad more batches to
        # the number of devices and the padded samples would be removed in
        # prediction outputs.
        if len(dev_batches) > 0:
            num_pad = dev_count - len(dev_batches)
            for i in range(len(dev_batches), dev_count):
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)

    queue = Queue.Queue(dev_count * prefetch_steps)
    p = Thread(target=worker, args=(reader, dev_count, queue))
    p.daemon = True
    p.start()

    while True:
        ret = queue.get()
        queue.task_done()
        if ret is not None:
            batches, num_pad = ret
            id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
            batch_buf = []
            flag_buf = []
            for idx, batch in enumerate(batches):
                # a batch is "real" unless it is one of the trailing pads
                flag = idx - len(batches) < -num_pad
                batch = postprocess_fn(batch, id)
                batch_buf.append(batch)
                flag_buf.append(flag)
            yield batch_buf, flag_buf, id
        else:
            break
    queue.join()
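if __name__ == '__main__':
    # Minimal smoke test (hypothetical reader): batches carry the '__task_id'
    # field that the train phase reads to tag each group of device batches.
    import numpy as np

    def _toy_reader():
        for _ in range(3):
            yield {'__task_id': np.zeros([1, 1], dtype='int64'),
                   'token_ids': np.zeros([2, 8], dtype='int64')}

    for batch_buf, flag_buf, task_id in data_feeder(_toy_reader):
        print(task_id, flag_buf)  # flags mark real (non-padded) device batches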
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
...@@ -16,6 +16,9 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader

class Reader(reader):

    def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
...@@ -85,7 +88,6 @@ class Reader(reader):
        })
        return returns

    def load_data(self):
        self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
......
...@@ -83,8 +83,6 @@ class Reader(reader):
            return outputs

        for batch in self._data_generator():
            yield list_to_dict(batch)

    def get_epoch_outputs(self):
......
...@@ -15,6 +15,7 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
import numpy as np

class Reader(reader):
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
...@@ -19,12 +19,26 @@ from __future__ import print_function
import numpy as np

def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
    """
    Add mask for batch_tokens, return out, mask_label, mask_pos;
    Note: mask_pos corresponds to batch_tokens after padding;
    """
    max_len = max([len(sent) for sent in batch_tokens])
    multidev_batch_tokens = []
    multidev_mask_label = []
    multidev_mask_pos = []

    big_batch_tokens = batch_tokens
    stride = len(batch_tokens) // dev_count
    if stride == 0:
        return None, None, None
    p = stride

    for i in range(dev_count):
        # mask each per-device slice of the big batch independently
        batch_tokens = big_batch_tokens[p - stride:p]
        p += stride
        mask_label = []
        mask_pos = []
        prob_mask = np.random.rand(total_token_num)
...@@ -69,7 +83,12 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
            mask_pos.append(sent_index * max_len + token_index)
        mask_label = np.array(mask_label).astype("int64").reshape([-1])
        mask_pos = np.array(mask_pos).astype("int64").reshape([-1])

        multidev_batch_tokens.extend(batch_tokens)
        multidev_mask_label.append(mask_label)
        multidev_mask_pos.append(mask_pos)

    return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
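# Illustrative helper (not used by mask()) mirroring the slicing contract
# above: with dev_count devices, row block i of the big batch goes to device i.
def _split_for_devices(batch, dev_count):
    stride = len(batch) // dev_count
    return [batch[i * stride:(i + 1) * stride] for i in range(dev_count)]

# e.g. _split_for_devices([[1], [2], [3], [4]], 2) -> [[[1], [2]], [[3], [4]]]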
def prepare_batch_data(insts,
...@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False,
                       dev_count=1):
    """
    1. generate Tensor of data
    2. generate Tensor of position
...@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id,
        dev_count=dev_count)
    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out,
...@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
    return_list = [
        src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
    ]

    return return_list

def pad_batch_data(insts,
......
...@@ -29,11 +29,14 @@ import six
from io import open
from collections import namedtuple

from . import gpu_dev_count
import paddlepalm as palm
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
from paddlepalm.reader.utils.batching4ernie import pad_batch_data
from paddlepalm.reader.utils.mlm_batching import prepare_batch_data

log = logging.getLogger(__name__)

if six.PY3:
...@@ -478,14 +481,12 @@ class MaskLMReader(BaseReader):
                # max_len=self.max_seq_len,  # Note: padding to the max length would make mask_pos mismatch the actual positions, since mask_pos is computed over the in-batch max length.
                return_input_mask=True,
                return_max_len=False,
                return_num_token=False,
                dev_count=gpu_dev_count)

            # yield batch: the first five tensors are split across devices,
            # while the per-device mask_label/mask_pos lists are unstacked
            for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
                yield piece

        return wrapper
...@@ -952,11 +953,20 @@ class MRCReader(BaseReader):
                if to_append:
                    batch_records.append(record)
                else:
                    # yield self._pad_batch_records(batch_records, phase == "train")
                    ds = ['s'] * 8
                    for piece in palm.distribute.yield_pieces(
                            self._pad_batch_records(batch_records, phase == 'train'),
                            ds, batch_size):
                        yield piece
                    batch_records, max_len = [record], len(record.token_ids)

            if phase == 'pred' and batch_records:
                for piece in palm.distribute.yield_pieces(
                        self._pad_batch_records(batch_records, phase == 'train'),
                        ds, batch_size):
                    yield piece

    def _pad_batch_records(self, batch_records, is_training):
        batch_token_ids = [record.token_ids for record in batch_records]
...@@ -1043,12 +1053,8 @@ class MRCReader(BaseReader):
            for batch_data in self._prepare_batch_data(
                    features, batch_size, phase=phase):
                yield batch_data

        return wrapper
......
...@@ -34,12 +34,13 @@ class TaskInstance(object):
        self._name = name
        self._config = config
        self._verbose = verbose
        self._id = id

        check_req_args(config, name)

        # parse Reader and Paradigm
        self.reader_name = config['reader']
        reader_mod = importlib.import_module(READER_DIR + '.' + self.reader_name)
        Reader = getattr(reader_mod, 'Reader')

        parad_name = config['paradigm']
...@@ -104,13 +105,18 @@ class TaskInstance(object):
    def epoch_postprocess(self, epoch_inputs, phase):
        return self._task_layer[phase].epoch_postprocess(epoch_inputs)

    def save(self, suffix='', prog=None):
        dirpath = self._save_infermodel_path + suffix
        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment=True)
        # save the given program if provided, otherwise a clone of the default main program
        if prog is not None:
            save_prog = prog
        else:
            save_prog = fluid.default_main_program().clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, save_prog)

        conf = {}
        for k, strv in self._save_protocol.items():
...@@ -137,6 +143,10 @@ class TaskInstance(object):
    def name(self):
        return self._name

    @property
    def tid(self):
        return self._id

    @property
    def Reader(self):
        return self._Reader
...@@ -169,7 +179,7 @@ class TaskInstance(object):
    @property
    def pred_input(self):
        return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))

    @pred_input.setter
    def pred_input(self, val):
......
...@@ -85,6 +85,7 @@ class TaskParadigm(task_paradigm):
        else:
            unique_id = inputs['reader']['unique_ids']

        enc_out = inputs['backbone']['encoder_outputs']
        logits = fluid.layers.fc(
            input=enc_out,
......
...@@ -111,41 +111,39 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes
    joint_shape_and_dtypes: essentially determined by the attrs of the backbone and paradigms, with the -1 (variable) dimensions auto-filled from the reader attrs, so checking it against the iterator gives a runtime batch-correctness check
    """
    task_ids = range(len(iterators))
    weights = [mr / float(sum(mrs)) for mr in mrs]
    if not keep_one_task:
        dev_count = 1

    results = {}
    pos_to_outname = {}
    for id in task_ids:
        pos_to_outname[id] = {j: i for i, j in outname_to_pos[id].items()}
        result = _zero_batch(joint_shape_and_dtypes[id])
        outbuf = {}
        outputs = next(iterators[id])  # dict type
        outbuf[id] = outputs
        prefix = iterator_prefixes[id]
        for outname, val in outputs.items():
            task_outname = prefix + '/' + outname

            if outname in outname_to_pos[id]:
                idx = outname_to_pos[id][outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
                result[idx] = val

            if task_outname in outname_to_pos[id]:
                idx = outname_to_pos[id][task_outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
                result[idx] = val

        results[id] = result

    fake_batch = results
    dev_count_bak = dev_count

    def iterator():
        v = verbose
        has_show_warn = False
        while True:
            id = np.random.choice(task_ids, p=weights)
            results = fake_batch
            if v > 0:
                print('----- debug joint iterator -----')
                print('sampled task id: '+str(id))
...@@ -153,8 +151,8 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes
            for i in range(dev_count):
                results[id][outname_to_pos[id]['__task_id']] = task_id_tensor
                assert outname_to_pos[id]['__task_id'] == 0

                if id in outbuf:
                    outputs = outbuf[id]
...@@ -165,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes
                if 'token_ids' in outputs:
                    val1 = len(outputs['token_ids'])
                    val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
                    results[id][outname_to_pos[id]['batch_size']] = val

                    val2 = len(outputs['token_ids'][0])
                    val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
                    results[id][outname_to_pos[id]['seqlen']] = val

                    val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
                    results[id][outname_to_pos[id]['batchsize_x_seqlen']] = val
                else:
                    if not has_show_warn:
                        print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
...@@ -184,33 +182,33 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes
                        print('reader generate: '+outname)
                    task_outname = prefix + '/' + outname

                    if outname in outname_to_pos[id]:
                        idx = outname_to_pos[id][outname]
                        if v > 0:
                            print(outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
                        results[id][idx] = val

                    if task_outname in outname_to_pos[id]:
                        idx = outname_to_pos[id][task_outname]
                        if v > 0:
                            print(task_outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
                        results[id][idx] = val

            if v > 0:
                print('yielded batch len and shapes:')
                print(len(results[id]))
                for i in results[id]:
                    print(np.shape(i))
                print('')
                v -= 1
            if return_type == 'list':
                yield results[id]
            elif return_type == 'dict':
                temp = {}
                for pos, i in enumerate(results[id]):
                    temp[pos_to_outname[id][pos]] = i
                yield temp

    return iterator
......