提交 ea0664b9 编写于 作者: X xixiaoyao

release 0.3

上级 391a9bbb
import downloader
from mtl_controller import Controller
import distribute
from distribute import gpu_dev_count, cpu_dev_count
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
del mtl_controller
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder
from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
def yield_pieces(data, distribute_strategy, batch_size):
"""
Args:
distribute_strategy: support s=split, c=copy, u=unstack,
"""
assert batch_size % dev_count == 0, "batch_size need to be integer times larger than dev_count."
print('data in yield pieces')
print(len(data))
assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
if isinstance(data, dict):
keys = list(data.keys())
data_list = [data[i] for i in keys]
ds_list = [distribute_strategy[i] for i in keys]
else:
assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors."
data_list = data
ds_list = distribute_strategy
stride = batch_size // dev_count
p = stride
# while p < len(data_list) + stride:
while p <= batch_size:
temp = []
for d, s in zip(data_list, ds_list):
s = s.strip().lower()
if s == 's' or s == 'split':
if p - stride >= len(d):
print('WARNING: no more examples to feed empty devices')
temp = []
return
temp.append(d[p-stride:p])
elif s == 'u' or s == 'unstack':
assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.'
if p//stride > len(d):
print('WARNING: no more examples to feed empty devices')
return
temp.append(d[p//stride-1])
elif s == 'c' or s == 'copy':
temp.append(d)
else:
raise NotImplementedError()
p += stride
if type(data) == dict:
yield dict(zip(*[keys, temp]))
else:
print('yielded pieces')
print(len(temp))
yield temp
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
if postprocess_fn is None:
def postprocess_fn(batch):
return batch
def worker(reader, dev_count, queue):
dev_batches = []
for index, data in enumerate(reader()):
if len(dev_batches) < dev_count:
dev_batches.append(data)
if len(dev_batches) == dev_count:
queue.put((dev_batches, 0))
dev_batches = []
# For the prediction of the remained batches, pad more batches to
# the number of devices and the padded samples would be removed in
# prediction outputs.
if len(dev_batches) > 0:
num_pad = dev_count - len(dev_batches)
for i in range(len(dev_batches), dev_count):
dev_batches.append(dev_batches[-1])
queue.put((dev_batches, num_pad))
queue.put(None)
queue = Queue.Queue(dev_count*prefetch_steps)
p = Thread(
target=worker, args=(reader, dev_count, queue))
p.daemon = True
p.start()
while True:
ret = queue.get()
queue.task_done()
if ret is not None:
batches, num_pad = ret
batch_buf = []
flag_buf = []
for idx, batch in enumerate(batches):
# flag = num_pad == 0
flag = idx-len(batches) < -num_pad
# if num_pad > 0:
# num_pad -= 1
batch = postprocess_fn(batch)
batch_buf.append(batch)
flag_buf.append(flag)
yield batch_buf, flag_buf
else:
break
queue.join()
......@@ -31,12 +31,11 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
from paddlepalm.utils.config_helper import PDConfig
from paddlepalm.utils.print_helper import print_dict
from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
from paddlepalm.distribute import data_feeder
from paddlepalm.default_settings import *
from default_settings import *
from task_instance import TaskInstance, check_instances
import Queue
from threading import Thread
DEBUG=False
VERBOSE=0
......@@ -185,6 +184,20 @@ def _fit_attr(conf, fit_attr, strict=False):
return conf
def create_feed_batch_process_fn(net_inputs):
def feed_batch_process_fn(data):
temp = {}
for q, var in net_inputs.items():
if isinstance(var, str) or isinstance(var, unicode):
temp[var] = data[q]
else:
temp[var.name] = data[q]
return temp
return feed_batch_process_fn
class Controller(object):
def __init__(self, config, task_dir='.', for_train=True):
......@@ -524,6 +537,7 @@ class Controller(object):
insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
pred_prog = inst.load(infer_model_path)
pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
if inst.reader['pred'] is None:
pred_reader = inst.Reader(inst.config, phase='pred')
inst.reader['pred'] = pred_reader
......@@ -574,18 +588,6 @@ class Controller(object):
return False
return True
def pack_multicard_feed(iterator, net_inputs, dev_count):
ret = []
mask = []
for i in range(dev_count):
temp = {}
content, flag = next(iterator)
for q, var in net_inputs.items():
temp[var.name] = content[q]
ret.append(temp)
mask.append(1 if flag else 0)
return ret, mask
# do training
fetch_names, fetch_list = zip(*fetches.items())
......@@ -594,50 +596,18 @@ class Controller(object):
epoch = 0
time_begin = time.time()
backbone_buffer = []
def multi_dev_reader(reader, dev_count):
def worker(reader, dev_count, queue):
dev_batches = []
for index, data in enumerate(reader()):
if len(dev_batches) < dev_count:
dev_batches.append(data)
if len(dev_batches) == dev_count:
queue.put((dev_batches, 0))
dev_batches = []
# For the prediction of the remained batches, pad more batches to
# the number of devices and the padded samples would be removed in
# prediction outputs.
if len(dev_batches) > 0:
num_pad = dev_count - len(dev_batches)
for i in range(len(dev_batches), dev_count):
dev_batches.append(dev_batches[-1])
queue.put((dev_batches, num_pad))
queue.put(None)
queue = Queue.Queue(dev_count*2)
p = Thread(
target=worker, args=(reader, dev_count, queue))
p.daemon = True
p.start()
while True:
ret = queue.get()
if ret is not None:
batches, num_pad = ret
queue.task_done()
for batch in batches:
flag = num_pad == 0
if num_pad > 0:
num_pad -= 1
yield batch, flag
else:
break
queue.join()
joint_iterator = multi_dev_reader(self._joint_iterator_fn, self.dev_count)
feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)
# palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2)
while not train_finish():
feed, mask = pack_multicard_feed(joint_iterator, self._net_inputs, self.dev_count)
feed, mask = next(distribute_feeder)
rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
while mask.pop() == False:
rt_outputs.pop()
rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist()
rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id
......@@ -714,19 +684,38 @@ class Controller(object):
fetch_names, fetch_vars = inst.pred_fetch_list
print('predicting...')
mapper = {k:v for k,v in inst.pred_input}
buf = []
for feed in inst.reader['pred'].iterator():
feed = _encode_inputs(feed, inst.name, cand_set=mapper)
feed = {mapper[k]: v for k,v in feed.items()}
feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1)
buf = []
for feed, mask in distribute_feeder:
print('before run')
rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
print('after run')
splited_rt_outputs = []
for item in rt_outputs:
splited_rt_outputs.append(np.split(item, len(mask)))
# assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
print(mask)
while mask.pop() == False:
print(mask)
for item in splited_rt_outputs:
item.pop()
rt_outputs = []
print('cancat')
for item in splited_rt_outputs:
rt_outputs.append(np.concatenate(item))
rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
inst.postprocess(rt_outputs, phase='pred')
print('leave feeder')
if inst.task_layer['pred'].epoch_inputs_attrs:
reader_outputs = inst.reader['pred'].get_epoch_outputs()
else:
reader_outputs = None
print('epoch postprocess')
inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
......@@ -16,6 +16,18 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
def match(learning_strategy='pointwise', siamese=False):
if siamese::
SiameseMatchReader(..., learning_strategy)
else:
ClassifyReader(..., learning_strategy)
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
......@@ -67,6 +79,28 @@ class Reader(reader):
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
if siamese:
if learning_strategy == 'pointwise':
{'token_ids_A':...,
'token_ids_B':...,
"position_ids_A": [[-1, -1, 1], 'int64'],
"position_ids_B": [[-1, -1, 1], 'int64'],
elif ...:
{
'token_ids_A',
'token_ids_B':...,
'tokeb_ids_A_neg':...
}
else:
if learning_strategy == 'pairwise':
return {
"token_ids": ...,
"token_ids_neg": ...
...
}
else:
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
......
......@@ -83,8 +83,6 @@ class Reader(reader):
return outputs
for batch in self._data_generator():
# print(np.shape(list_to_dict(batch)['token_ids']))
# print(list_to_dict(batch)['mask_label'].tolist())
yield list_to_dict(batch)
def get_epoch_outputs(self):
......
......@@ -15,6 +15,7 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
import numpy as np
class Reader(reader):
......@@ -73,6 +74,7 @@ class Reader(reader):
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"start_positions": [[-1, 1], 'int64'],
"unique_ids": [[-1, 1], 'int64'],
"end_positions": [[-1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
......@@ -108,6 +110,7 @@ class Reader(reader):
return outputs
for batch in self._data_generator():
print(len(list_to_dict(batch)))
yield list_to_dict(batch)
def get_epoch_outputs(self):
......
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
......@@ -19,57 +19,76 @@ from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
multidev_batch_tokens = []
multidev_mask_label = []
multidev_mask_pos = []
big_batch_tokens = batch_tokens
stride = len(batch_tokens) // dev_count
if stride == 0:
return None, None, None
p = stride
for i in range(dev_count):
batch_tokens = big_batch_tokens[p-stride:p]
p += stride
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
multidev_batch_tokens.extend(batch_tokens)
multidev_mask_label.append(mask_label)
multidev_mask_pos.append(mask_pos)
return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
def prepare_batch_data(insts,
......@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
task_id=0,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
return_num_token=False,
dev_count=1):
"""
1. generate Tensor of data
2. generate Tensor of position
......@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
MASK=mask_id,
dev_count=dev_count)
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
......@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
return_list = [
src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
]
return return_list if len(return_list) > 1 else return_list[0]
return return_list
def pad_batch_data(insts,
......
......@@ -29,11 +29,14 @@ import six
from io import open
from collections import namedtuple
from . import gpu_dev_count
import paddlepalm as palm
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
from paddlepalm.reader.utils.batching4ernie import pad_batch_data
from paddlepalm.reader.utils.mlm_batching import prepare_batch_data
log = logging.getLogger(__name__)
if six.PY3:
......@@ -435,14 +438,12 @@ class MaskLMReader(BaseReader):
# max_len=self.max_seq_len, # 注意,如果padding到最大长度,会导致mask_pos与实际位置不对应。因为mask pos是基于batch内最大长度来计算的。
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return_num_token=False,
dev_count=gpu_dev_count)
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
# yield batch
for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
yield piece
return wrapper
......@@ -890,11 +891,20 @@ class MRCReader(BaseReader):
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records, phase == "train")
# yield self._pad_batch_records(batch_records, phase == "train")
ds = ['s'] * 8
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records, phase == 'train'),
ds, batch_size):
yield piece
batch_records, max_len = [record], len(record.token_ids)
if phase == 'pred' and batch_records:
yield self._pad_batch_records(batch_records, phase == "train")
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records, phase == 'train'),
ds, batch_size):
yield piece
def _pad_batch_records(self, batch_records, is_training):
batch_token_ids = [record.token_ids for record in batch_records]
......@@ -981,12 +991,8 @@ class MRCReader(BaseReader):
for batch_data in self._prepare_batch_data(
features, batch_size, phase=phase):
if len(all_dev_batches) < dev_count:
all_dev_batches.append(batch_data)
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
yield batch_data
return wrapper
......
......@@ -169,7 +169,7 @@ class TaskInstance(object):
@property
def pred_input(self):
return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))
@pred_input.setter
def pred_input(self, val):
......
......@@ -59,8 +59,13 @@ class TaskParadigm(task_paradigm):
def build(self, inputs, scope_name=""):
if self._is_training:
labels = inputs["reader"]["label_ids"]
if learning_strategy == 'pointwise':
labels = inputs["reader"]["label_ids"]
elif learning_strategy == 'pairwise':
inputs['backbone']["sentence_enbedding_neg"]
cls_feats = inputs["backbone"]["sentence_pair_embedding"]
cls_feats = inputs["backbone"]["sentence_pair_embedding_neg"]
if self._is_training:
cls_feats = fluid.layers.dropout(
......
......@@ -82,9 +82,11 @@ class TaskParadigm(task_paradigm):
end_positions = fluid.layers.elementwise_min(end_positions, max_position)
start_positions.stop_gradient = True
end_positions.stop_gradient = True
fluid.layers.Print(start_positions)
else:
unique_id = inputs['reader']['unique_ids']
enc_out = inputs['backbone']['encoder_outputs']
logits = fluid.layers.fc(
input=enc_out,
......
......@@ -204,13 +204,13 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
print(np.shape(i))
print('')
v -= 1
if return_type == 'list':
yield results
elif return_type == 'dict':
temp = {}
for pos, i in enumerate(results):
temp[pos_to_outname[pos]] = i
yield temp
if return_type == 'list':
yield results
elif return_type == 'dict':
temp = {}
for pos, i in enumerate(results):
temp[pos_to_outname[pos]] = i
yield temp
return iterator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册