Commit ba55793f authored by wangxiao1021

change to switch op based

Parent de37fd75
@@ -632,7 +632,9 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support input to the ERNIE model
#### Reader tool for text matching datasets: match
This reader loads and processes text matching datasets. It accepts datasets in [tsv format](https://en.wikipedia.org/wiki/Tab-separated_values), which should contain three columns. Under the `pointwise` learning strategy, one column holds the sample label `label` and the other two hold the texts to be matched, `text_a` and `text_b`; under the `pairwise` learning strategy, one column holds the sample `text_a` to be matched, and the other two hold its positive example `text_b` and negative example `text_b_neg`. The format looks like:
1. For the `pointwise` learning strategy:
```yaml
label text_a text_b
@@ -642,10 +644,22 @@ label text_a text_b
1 What has Pakistan told phone companies? **[TAB]** Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday.
```
2. For the `pairwise` learning strategy:
```yaml
text_a text_b text_b_neg
arrg ... ubuntunoob and ubuntu_user ... your nicks are confusing ^^ i d say it was **[TAB]** how that ... dynamic size of the container another idea would be an installation on an ( external ) flash-stick/card **[TAB]** will try now thanks if you have ati and md0 - i m no further help btw
got an error message while installing __number__ need help ( initrmfs ) mount failure error do you see this grub no a little more info would help ;-) did you boot a cd or pen drive to install or install from windows was this a install from windows which is called a wubi how much memory does the computer have memory=ram so you got installed no errors and get this on reboot so when did you get this error did you burn it as a image **[TAB]** were you able to check the md5sum of the iso here is a link on md5sum i suspect it may not be this but never hurts to check __url__ **[TAB]** you would have to capture the pcl convert with hp 2xx then print that so do i set up another printer in cups with that driver but pointed to output to my cups pdf printer or do i need to pipe it through the driver on a lower level somehow
okay i come from a windows background .. currently running v __number__ __number__ and having a video card ( ati ) issue ... if i have an issue like this ( in windows ) i would go to the vendor site locate a current driver and install in ubuntu it automatically downloaded a driver - this driver i assume does not come from the vendor site but rather a ubuntu repository of tested/approved drivers is that a correct assumption yes that is correct **[TAB]** so given the downloaded driver is not performing properly i went to ati and found they have a newer version driver what is the correct process to load the new version do i need to uninstall ( how ) the old version the new version is a run file - i am not familiar with what is the issue you re having with the ubuntu-supplied driver **[TAB]** ls -ld __path__ __path__ __path__ __path__ wrxr-xr-x
hey he wanted excitement __url__ __url__ dapper multivers thank you so much now i can do apt-get build-dep mythtv and compile it myself np i cannot install those packages i am also needing them why ca n't you install them i just verified they re installable i am on a default dapper install with all extra repositories in sources list uncommented and cant then you do n't have the correct repo enabled **[TAB]** lame installed ( none ) apt-cache policy lame **[TAB]** i am using mercury ... i think it is better than amsn i lost the curiosity for this __number__ years ago but i ve back are you using a router
```
***Note: the first line of the dataset must be a header that names each column***
The reader's output (the data yielded by the generator at each step) contains the following fields:
1. For the `pointwise` learning strategy:
```yaml
token_ids: a matrix of shape [batch_size, seq_len]; each row is one sample (a text pair), each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, and each element is the position id of the corresponding token.
@@ -657,6 +671,22 @@ task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support input to the ERNIE model
During the prediction phase, the data yielded by the reader does not contain the `label_ids` field.
2. For the `pairwise` learning strategy:
```yaml
token_ids: a matrix of shape [batch_size, seq_len]; each row is one positive sample (the text pair text_a, text_b), each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids: a matrix of shape [batch_size, seq_len]; each row is one sample, and each element is the position id of the corresponding token.
segment_ids: a matrix of shape [batch_size, seq_len]; elements are 0 at the token positions of text 1 (text_a) and 1 at the token positions of text 2 (text_b). Used to support input to models such as BERT and ERNIE.
input_mask: a matrix of shape [batch_size, seq_len]; each element is 0 or 1, marking whether the position is padding (1 for a real token, 0 for padding).
label_ids: a vector of shape [batch_size]; each element is the sample's class label, 0 meaning the two texts do not match and 1 meaning they match.
task_ids: an all-zero matrix of shape [batch_size, seq_len], used to support input to the ERNIE model.
token_ids_neg: a matrix of shape [batch_size, seq_len]; each row is one negative sample (the text pair text_a, text_b_neg), each element is the vocabulary id of the corresponding token, and the two texts are separated by the id of `[SEP]`.
position_ids_neg: a matrix of shape [batch_size, seq_len]; each row is one negative sample, and each element is the position id of the corresponding token.
segment_ids_neg: a matrix of shape [batch_size, seq_len]; elements are 0 at the token positions of text 1 (text_a) and 1 at the token positions of text 2 (text_b_neg). Used to support input to models such as BERT and ERNIE.
input_mask_neg: a matrix of shape [batch_size, seq_len]; each element is 0 or 1, marking whether the position is padding (1 for a real token, 0 for padding).
task_ids_neg: an all-zero matrix of shape [batch_size, seq_len], used to support input to the ERNIE model.
```
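To make the field layout concrete, here is a small illustrative sketch (not the PALM reader itself; the token and segment ids are made up) of how one text pair is packed into `token_ids` / `segment_ids` / `position_ids`:

```python
def pack_pair(tokens_a, tokens_b, cls_id=1, sep_id=2):
    # [CLS] text_a [SEP] text_b [SEP]; segment 0 covers text_a, segment 1 covers text_b
    token_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    position_ids = list(range(len(token_ids)))
    return token_ids, segment_ids, position_ids

print(pack_pair([11, 12], [21, 22, 23]))
# -> ([1, 11, 12, 2, 21, 22, 23, 2], [0, 0, 0, 0, 1, 1, 1, 1], [0, 1, 2, 3, 4, 5, 6, 7])
```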
#### Reader tool for machine reading comprehension datasets: mrc
......
import downloader
from mtl_controller import Controller
import distribute
from distribute import gpu_dev_count, cpu_dev_count
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
from paddle import fluid
import os
import multiprocessing

# number of visible CUDA devices; when none, fall back to CPU_NUM (or all cores)
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder
from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread

# one logical device per GPU when any are visible, otherwise per CPU core
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
def yield_pieces(data, distribute_strategy, batch_size):
    """
    Args:
        distribute_strategy: strategy for each field; s=split, c=copy, u=unstack
    """
    assert batch_size % dev_count == 0, "batch_size must be an integer multiple of dev_count."
    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[i] for i in keys]
        ds_list = [distribute_strategy[i] for i in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    stride = batch_size // dev_count
    p = stride
    while p <= batch_size:
        # slice out the piece of each field that belongs to the current device
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p-stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less than or equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p//stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()
        p += stride
        if type(data) == dict:
            yield dict(zip(*[keys, temp]))
        else:
            yield temp
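A hedged usage sketch of `yield_pieces`, assuming the module-level `dev_count` is 2 (the field names are illustrative only):

```python
import numpy as np

# one joint batch of 4 examples for 2 devices: 'token_ids' is split along the
# batch dim, 'task_id' is unstacked (one entry per device), 'vocab_size' copied
batch = {
    'token_ids': np.arange(12).reshape(4, 3),
    'task_id': np.array([7, 7]),
    'vocab_size': 30000,
}
strategy = {'token_ids': 's', 'task_id': 'u', 'vocab_size': 'c'}

for piece in yield_pieces(batch, strategy, batch_size=4):
    print({k: np.shape(v) for k, v in piece.items()})
# two pieces, each with token_ids of shape (2, 3)
```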
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
    if postprocess_fn is None:
        def postprocess_fn(batch, id=-1):  # accept the task id to match the call below
            return batch

    def worker(reader, dev_count, queue):
        # group dev_count consecutive batches so each device receives one
        dev_batches = []
        for index, data in enumerate(reader()):
            if len(dev_batches) < dev_count:
                dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For the prediction of the remained batches, pad more batches to
        # the number of devices and the padded samples would be removed in
        # prediction outputs.
        if len(dev_batches) > 0:
            num_pad = dev_count - len(dev_batches)
            for i in range(len(dev_batches), dev_count):
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)

    queue = Queue.Queue(dev_count * prefetch_steps)
    p = Thread(target=worker, args=(reader, dev_count, queue))
    p.daemon = True
    p.start()
    while True:
        ret = queue.get()
        queue.task_done()
        if ret is not None:
            batches, num_pad = ret
            id = batches[0]['__task_id'][0][0] if phase == 'train' else -1
            batch_buf = []
            flag_buf = []
            for idx, batch in enumerate(batches):
                # flags are False only for the trailing padded batches
                flag = idx - len(batches) < -num_pad
                batch = postprocess_fn(batch, id)
                batch_buf.append(batch)
                flag_buf.append(flag)
            yield batch_buf, flag_buf, id
        else:
            break
    queue.join()
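And a small sketch of `data_feeder` in prediction mode, again assuming `dev_count` is 2: with three batches the last device receives a padded copy, whose flag comes back `False` so its outputs can be dropped later:

```python
def toy_reader():
    for i in range(3):
        yield {'x': [i]}

for batches, flags, task_id in data_feeder(toy_reader, phase='pred'):
    print(batches, flags, task_id)  # flags mark real (True) vs padded (False) batches
```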
@@ -31,12 +31,11 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
from paddlepalm.utils.config_helper import PDConfig
from paddlepalm.utils.print_helper import print_dict
from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
from paddlepalm.distribute import data_feeder

from default_settings import *
from task_instance import TaskInstance, check_instances

DEBUG=False
VERBOSE=0
@@ -185,6 +184,27 @@ def _fit_attr(conf, fit_attr, strict=False):
    return conf
def create_feed_batch_process_fn(net_inputs):

    def feed_batch_process_fn(data, id=-1):
        temp = {}
        # with the switch-op design there is one net_inputs map per task id;
        # id == -1 means a single (prediction) input map
        inputs = net_inputs[id] if id != -1 else net_inputs
        for q, var in inputs.items():
            if isinstance(var, str) or isinstance(var, unicode):
                temp[var] = data[q]
            else:
                temp[var.name] = data[q]
        return temp

    return feed_batch_process_fn
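For illustration, the mapping it produces (hypothetical field and variable names; here `net_inputs` maps names to names directly rather than to fluid Variables):

```python
net_inputs = {'token_ids': 'src_ids'}         # reader field -> feed variable name
fn = create_feed_batch_process_fn(net_inputs)
print(fn({'token_ids': [[1, 2, 3]]}))         # {'src_ids': [[1, 2, 3]]}
```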
class Controller(object):

    def __init__(self, config, task_dir='.', for_train=True):
@@ -330,6 +350,7 @@ class Controller(object):
dev_count = self.dev_count
num_instances = len(instances)
mrs = self.mrs
branch = fluid.data(name="branch", shape=[1], dtype='int32')  # selects the task branch for switch_case
# set first_target/main task instance
main_inst = None
@@ -349,35 +370,51 @@ class Controller(object):
# create reader, task
# then check i/o across reader, backbone and task_layer
task_attrs = {}
pred_task_attrs = []
joint_input_names = {}
joint_shape_and_dtypes = {}
name_to_position = {}
for i in range(num_instances):
    train_reader = instances[i].Reader(instances[i].config, phase='train')
    instances[i].reader['train'] = train_reader
    train_parad = instances[i].Paradigm(instances[i].config, phase='train', backbone_config=bb_conf)
    instances[i].task_layer['train'] = train_parad
    task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], instances[i].name)
    task_attrs[i] = task_attr_from_reader

    _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train')
    _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train')
    _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone')

    if instances[i].is_target:
        if 'pred_file' not in instances[i].config:
            instances[i].config['pred_file'] = ''
        pred_reader = instances[i].Reader(instances[i].config, phase='pred')
        pred_parad = instances[i].Paradigm(instances[i].config, phase='pred', backbone_config=bb_conf)
        instances[i].task_layer['pred'] = pred_parad
        task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], instances[i].name)
        pred_task_attrs.append(task_attr_from_reader)
        _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
        _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
        _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')

    # merge reader input attrs from backbone and task_instances
    joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i] = merge_input_attrs(train_backbone.inputs_attr, task_attrs[i])

pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
# shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]

if DEBUG:
@@ -387,10 +424,11 @@ class Controller(object):
    print('joint input shape and dtypes:')
    print(joint_shape_and_dtypes)
# load data
data_fns = {}
for i in range(num_instances):
    print(instances[i].name+": preparing data...", end='')
    instances[i].reader['train'].load_data()
    print('ok!')
# merge dataset iterators and create net input vars
@@ -406,65 +444,65 @@ class Controller(object):
joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
self._joint_iterator_fn = joint_iterator_fn
input_attrs = {}
net_inputs = {}
bb_output_vars = {}
bb_output_fns = {}

# prepare predict vars for saving inference model
pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
pred_prog = fluid.Program()
pred_init_prog = fluid.Program()
self._pred_prog = pred_prog
with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog):
    pred_net_inputs = create_net_inputs(pred_input_attrs)
    pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')

task_inputs = {}
task_output_vars = {}
task_fns = {}

def get_loss(i):
    input_attrs[i] = [[m, j, k] for m, (j,k) in zip(joint_input_names[i], joint_shape_and_dtypes[i])]
    net_inputs[i] = create_net_inputs(input_attrs[i], async=False)
    # build the per-task backbone and task layer inside the switch_case branch
    bb_output_vars[i] = train_backbone.build(net_inputs[i], scope_name='__paddlepalm_')
    assert sorted(bb_output_vars[i].keys()) == sorted(train_backbone.outputs_attr.keys())

    task_inputs[i] = {'backbone': bb_output_vars[i]}
    task_inputs_from_reader = _decode_inputs(net_inputs[i], instances[i].name)
    task_inputs[i]['reader'] = task_inputs_from_reader

    scope = instances[i].task_reuse_scope + '/'
    with fluid.unique_name.guard(scope):
        output_vars = instances[i].build_task_layer(task_inputs[i], phase='train', scope=scope)
        output_vars = {instances[i].name+'/'+key: val for key, val in output_vars.items()}
        loss_var = output_vars[instances[i].name+'/loss']
        task_output_vars[i] = output_vars

    if instances[i].is_target:
        with fluid.program_guard(pred_prog, pred_init_prog):
            cur_inputs = _decode_inputs(pred_net_inputs, instances[i].name)
            instances[i].pred_input = cur_inputs
            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = instances[i].task_reuse_scope + '/'
            with fluid.unique_name.guard(scope):
                instances[i].build_task_layer(pred_task_inputs, phase='pred', scope=scope)
    return loss_var

# freeze the loop variable so each branch closure keeps its own task id
for i in range(num_instances):
    def task_loss():
        task_id = i
        return lambda: get_loss(task_id)
    task_fns[i] = task_loss()

loss = layers.switch_case(
    branch_index=branch,
    branch_fns=task_fns
)
self._switched_loss = loss.name
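`switch_case` runs exactly one branch per step, selected by the `branch` tensor. A minimal, self-contained sketch of the pattern this commit adopts (assuming PaddlePaddle 1.6-style `fluid` APIs; the constant "losses" are placeholders, not the real task losses):

```python
import numpy as np
from paddle import fluid
from paddle.fluid import layers

branch = fluid.data(name='branch', shape=[1], dtype='int32')
# each branch function builds one "loss"; only the selected branch executes
task_fns = {
    0: lambda: layers.fill_constant(shape=[1], dtype='float32', value=0.1),
    1: lambda: layers.fill_constant(shape=[1], dtype='float32', value=0.2),
}
loss = layers.switch_case(branch_index=branch, branch_fns=task_fns)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out = exe.run(feed={'branch': np.array([1], dtype='int32')}, fetch_list=[loss])
print(out)  # [array([0.2], dtype=float32)]
```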
main_reader = main_inst.reader['train']
num_examples = main_reader.num_examples
@@ -498,13 +536,14 @@ class Controller(object):
# prepare for train
self.train_backbone = train_backbone
#self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
self.saver_program = fluid.default_main_program()
self.train_program = self.saver_program

self.main_inst = main_inst
self.has_init_train = True
self.has_init_pred = True
self._net_inputs = net_inputs

self.exe.run(fluid.default_startup_program())
print("\nRandomly initialize parameters...\n")
@@ -525,6 +564,7 @@ class Controller(object):
    insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)

pred_prog = inst.load(infer_model_path)
# pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
if inst.reader['pred'] is None:
    pred_reader = inst.Reader(inst.config, phase='pred')
    inst.reader['pred'] = pred_reader
@@ -556,8 +596,6 @@ class Controller(object):
backbone = self.train_backbone
train_program = self.train_program
saver_program = self.saver_program

finish = []
for inst in instances:
    if inst.is_target:
@@ -575,90 +613,45 @@ class Controller(object):
        return False
    return True
# do training
fetch_names = {}
fetch_list = []

main_step = 0 # only count for main task
global_step = 0 # count for all tasks
epoch = 0
time_begin = time.time()
backbone_buffer = []
feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)

# build the fetch list once, outside the loop: only the switch_case-selected loss is fetched
fetch_list.append(self._switched_loss)

while not train_finish():
    feed, mask, id = next(distribute_feeder)
    # tell the switch op which task branch to run for this step
    feed[0].update({'branch': np.array([id], dtype='int32')})
    rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
    rt_loss = rt_outputs.pop()
    rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}

    cur_task = instances[id]
    global_step += 1
    cur_task.cur_train_step += 1

    cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
    if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
        cur_task.save(suffix='.step'+str(cur_task_global_step), prog=self._pred_prog)
    if global_step % main_conf.get('print_every_n_steps', 5) == 0:
        loss = rt_loss
        loss = np.mean(np.squeeze(loss)).tolist()

        time_end = time.time()
@@ -671,7 +664,7 @@ class Controller(object):
    if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
        print(cur_task.name+': train finished!')
        cur_task.save(prog=self._pred_prog)

    if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0:
        save_path = os.path.join(main_conf['save_path'], 'ckpt',
@@ -716,18 +709,38 @@ class Controller(object):
fetch_names, fetch_vars = inst.pred_fetch_list

print('predicting...')
feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1, phase='pred')

buf = []
for feed, mask, id in distribute_feeder:
    rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)

    # split each fetched tensor back into per-device pieces
    splited_rt_outputs = []
    for item in rt_outputs:
        splited_rt_outputs.append(np.split(item, len(mask)))

    # drop the trailing padded pieces (their flag in mask is False)
    while mask.pop() == False:
        for item in splited_rt_outputs:
            item.pop()

    rt_outputs = []
    for item in splited_rt_outputs:
        rt_outputs.append(np.concatenate(item))

    rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
    inst.postprocess(rt_outputs, phase='pred')

if inst.task_layer['pred'].epoch_inputs_attrs:
    reader_outputs = inst.reader['pred'].get_epoch_outputs()
else:
    reader_outputs = None

inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
@@ -741,7 +754,4 @@ if __name__ == '__main__':

__all__ = ["Controller"]
\ No newline at end of file
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
@@ -16,6 +16,9 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader

class Reader(reader):

    def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
@@ -84,7 +87,6 @@ class Reader(reader):
        "task_ids_neg": [[-1, -1], 'int64']
    })
    return returns

def load_data(self):
    self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
......
@@ -83,8 +83,6 @@ class Reader(reader):
    return outputs

for batch in self._data_generator():
    yield list_to_dict(batch)

def get_epoch_outputs(self):
......
@@ -15,6 +15,7 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
import numpy as np

class Reader(reader):
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
@@ -19,57 +19,76 @@ from __future__ import print_function
import numpy as np

def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
    """
    Add mask for batch_tokens, return out, mask_label, mask_pos;
    Note: mask_pos corresponds to batch_tokens after padding;
    """
    max_len = max([len(sent) for sent in batch_tokens])

    multidev_batch_tokens = []
    multidev_mask_label = []
    multidev_mask_pos = []

    big_batch_tokens = batch_tokens
    stride = len(batch_tokens) // dev_count
    if stride == 0:
        return None, None, None
    p = stride

    for i in range(dev_count):
        # each device gets its own slice of sentences and its own mask arrays
        batch_tokens = big_batch_tokens[p-stride:p]
        p += stride
        mask_label = []
        mask_pos = []
        prob_mask = np.random.rand(total_token_num)
        # Note: the first token is [CLS], so [low=1]
        replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
        pre_sent_len = 0
        prob_index = 0
        for sent_index, sent in enumerate(batch_tokens):
            mask_flag = False
            prob_index += pre_sent_len
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
            pre_sent_len = len(sent)

            # ensure at least mask one word in a sentence
            while not mask_flag:
                token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
                if sent[token_index] != SEP and sent[token_index] != CLS:
                    mask_label.append(sent[token_index])
                    sent[token_index] = MASK
                    mask_flag = True
                    mask_pos.append(sent_index * max_len + token_index)

        mask_label = np.array(mask_label).astype("int64").reshape([-1])
        mask_pos = np.array(mask_pos).astype("int64").reshape([-1])

        multidev_batch_tokens.extend(batch_tokens)
        multidev_mask_label.append(mask_label)
        multidev_mask_pos.append(mask_pos)

    return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
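A hedged example of the multi-device behavior (toy token ids; with the default CLS=1 and SEP=2, each of the 2 devices gets 2 of the 4 sentences plus its own mask_label/mask_pos arrays):

```python
import numpy as np

toy_batch = [[1, 10, 11, 2], [1, 12, 13, 2], [1, 14, 15, 2], [1, 16, 17, 2]]
total = sum(len(s) for s in toy_batch)
tokens, labels, positions = mask(toy_batch, total, vocab_size=100, dev_count=2)
print(len(tokens), [l.shape for l in labels], [p.shape for p in positions])
# -> 4 sentences back, plus 2 per-device mask_label / mask_pos arrays
```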
def prepare_batch_data(insts,
@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False,
                       dev_count=1):
    """
    1. generate Tensor of data
    2. generate Tensor of position
@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id,
        dev_count=dev_count)
    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out,
@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
    return_list = [
        src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
    ]

    return return_list


def pad_batch_data(insts,
......
@@ -29,11 +29,14 @@ import six
from io import open
from collections import namedtuple

from . import gpu_dev_count
import paddlepalm as palm
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
from paddlepalm.reader.utils.batching4ernie import pad_batch_data
from paddlepalm.reader.utils.mlm_batching import prepare_batch_data

log = logging.getLogger(__name__)

if six.PY3:
@@ -478,14 +481,12 @@ class MaskLMReader(BaseReader):
            # max_len=self.max_seq_len, # Note: padding to the maximum sequence length would misalign mask_pos, which is computed against the maximum length within the batch.
            return_input_mask=True,
            return_max_len=False,
            return_num_token=False,
            dev_count=gpu_dev_count)

        # yield one piece per device: the first five fields (ids and masks) are
        # split along the batch dim, while the per-device mask_label/mask_pos
        # arrays are unstacked
        for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
            yield piece

    return wrapper
@@ -952,11 +953,20 @@ class MRCReader(BaseReader):
    if to_append:
        batch_records.append(record)
    else:
        ds = ['s'] * 8
        for piece in palm.distribute.yield_pieces(
                self._pad_batch_records(batch_records, phase == 'train'),
                ds, batch_size):
            yield piece
        batch_records, max_len = [record], len(record.token_ids)

if phase == 'pred' and batch_records:
    for piece in palm.distribute.yield_pieces(
            self._pad_batch_records(batch_records, phase == 'train'),
            ds, batch_size):
        yield piece
def _pad_batch_records(self, batch_records, is_training):
    batch_token_ids = [record.token_ids for record in batch_records]
@@ -1043,12 +1053,8 @@ class MRCReader(BaseReader):
for batch_data in self._prepare_batch_data(
        features, batch_size, phase=phase):
    yield batch_data

return wrapper
......
@@ -34,12 +34,13 @@ class TaskInstance(object):
    self._name = name
    self._config = config
    self._verbose = verbose
    self._id = id

    check_req_args(config, name)

    # parse Reader and Paradigm
    self.reader_name = config['reader']
    reader_mod = importlib.import_module(READER_DIR + '.' + self.reader_name)
    Reader = getattr(reader_mod, 'Reader')

    parad_name = config['paradigm']
@@ -104,13 +105,18 @@
def epoch_postprocess(self, epoch_inputs, phase):
    return self._task_layer[phase].epoch_postprocess(epoch_inputs)

def save(self, suffix='', prog=None):
    dirpath = self._save_infermodel_path + suffix
    self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

    # save from the given program if provided (e.g. the prediction program),
    # otherwise fall back to a clone of the default main program
    if prog is not None:
        save_prog = prog
    else:
        save_prog = fluid.default_main_program().clone()

    fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, save_prog)

    conf = {}
    for k, strv in self._save_protocol.items():
@@ -137,6 +143,10 @@
def name(self):
    return self._name

@property
def tid(self):
    return self._id

@property
def Reader(self):
    return self._Reader
@@ -169,7 +179,7 @@
@property
def pred_input(self):
    return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))

@pred_input.setter
def pred_input(self, val):
......
@@ -85,6 +85,7 @@ class TaskParadigm(task_paradigm):
    else:
        unique_id = inputs['reader']['unique_ids']

    enc_out = inputs['backbone']['encoder_outputs']
    logits = fluid.layers.fc(
        input=enc_out,
......
@@ -111,41 +111,39 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes,
    joint_shape_and_dtypes: essentially derived from the attrs of the backbone and paradigms, with the -1 (variable) dimensions filled in automatically from the reader attrs; validating against the iterator therefore gives a runtime check that batches are well-formed.
    """
    task_ids = range(len(iterators))
    weights = [mr / float(sum(mrs)) for mr in mrs]
    if not keep_one_task:
        dev_count = 1

    results = {}
    pos_to_outname = {}
    outbuf = {}  # buffer the first batch of every task for reuse in iterator()
    for id in task_ids:
        pos_to_outname[id] = {j:i for i,j in outname_to_pos[id].items()}
        result = _zero_batch(joint_shape_and_dtypes[id])
        outputs = next(iterators[id]) # dict type
        outbuf[id] = outputs
        prefix = iterator_prefixes[id]
        for outname, val in outputs.items():
            task_outname = prefix + '/' + outname

            if outname in outname_to_pos[id]:
                idx = outname_to_pos[id][outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
                result[idx] = val

            if task_outname in outname_to_pos[id]:
                idx = outname_to_pos[id][task_outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
                result[idx] = val

        results[id] = result

    def iterator():
        v = verbose
        has_show_warn = False
        while True:
            id = np.random.choice(task_ids, p=weights)
            if v > 0:
                print('----- debug joint iterator -----')
                print('sampled task id: '+str(id))
@@ -153,8 +151,8 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes,
            for i in range(dev_count):
                results[id][outname_to_pos[id]['__task_id']] = task_id_tensor
                assert outname_to_pos[id]['__task_id'] == 0

                if id in outbuf:
                    outputs = outbuf[id]
@@ -165,14 +163,14 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes,
                if 'token_ids' in outputs:
                    val1 = len(outputs['token_ids'])
                    val = _check_and_adapt_shape_dtype(np.array([val1], dtype='int64'), [[1], 'int64'], iterator_prefixes[id]+' tokenids: ')
                    results[id][outname_to_pos[id]['batch_size']] = val

                    val2 = len(outputs['token_ids'][0])
                    val = _check_and_adapt_shape_dtype(np.array([val2], dtype='int64'), [[1], 'int64'])
                    results[id][outname_to_pos[id]['seqlen']] = val

                    val = _check_and_adapt_shape_dtype(np.array([val1*val2], dtype='int64'), [[1], 'int64'])
                    results[id][outname_to_pos[id]['batchsize_x_seqlen']] = val
                else:
                    if not has_show_warn:
                        print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
@@ -184,33 +182,33 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes,
                    print('reader generate: '+outname)
                    task_outname = prefix + '/' + outname

                    if outname in outname_to_pos[id]:
                        idx = outname_to_pos[id][outname]
                        if v > 0:
                            print(outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ')
                        results[id][idx] = val

                    if task_outname in outname_to_pos[id]:
                        idx = outname_to_pos[id][task_outname]
                        if v > 0:
                            print(task_outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ')
                        results[id][idx] = val

                if v > 0:
                    print('yielded batch len and shapes:')
                    print(len(results[id]))
                    for i in results[id]:
                        print(np.shape(i))
                    print('')
                    v -= 1
                if return_type == 'list':
                    yield results[id]
                elif return_type == 'dict':
                    temp = {}
                    for pos, i in enumerate(results[id]):
                        temp[pos_to_outname[id][pos]] = i
                    yield temp

    return iterator
......