Commit ea0664b9 authored by xixiaoyao

release 0.3

Parent 391a9bbb
import downloader
from mtl_controller import Controller
import distribute
from distribute import gpu_dev_count, cpu_dev_count
del interface
del task_instance
del default_settings
del utils
del mtl_controller
\ No newline at end of file
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
from reader import yield_pieces, data_feeder
from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
def yield_pieces(data, distribute_strategy, batch_size):
    """Split a batch into dev_count pieces according to a per-tensor strategy.

    Args:
        distribute_strategy: per-tensor strategy; supports s=split, c=copy, u=unstack.
    """
    assert batch_size % dev_count == 0, "batch_size must be an integer multiple of dev_count."
    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]
    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[i] for i in keys]
        ds_list = [distribute_strategy[i] for i in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or a dict of multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    stride = batch_size // dev_count
    p = stride
    while p <= batch_size:
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p-stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less than or equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p//stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()
        p += stride
        if isinstance(data, dict):
            yield dict(zip(*[keys, temp]))
        else:
            yield temp
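
For illustration, here is a minimal sketch of how yield_pieces splits one batch, assuming the module-level dev_count resolves to 2; the tensors and keys below are toy stand-ins:

# Hypothetical toy call with dev_count == 2 and batch_size == 4.
batch = {
    'token_ids': [[1, 2], [3, 4], [5, 6], [7, 8]],  # 's': split across devices
    'task_id': 7,                                   # 'c': copied to every device
}
strategy = {'token_ids': 's', 'task_id': 'c'}
for piece in yield_pieces(batch, strategy, batch_size=4):
    print(piece)
# -> {'token_ids': [[1, 2], [3, 4]], 'task_id': 7}
# -> {'token_ids': [[5, 6], [7, 8]], 'task_id': 7}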
def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
    if postprocess_fn is None:
        def postprocess_fn(batch):
            return batch

    def worker(reader, dev_count, queue):
        dev_batches = []
        for index, data in enumerate(reader()):
            if len(dev_batches) < dev_count:
                dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For the prediction of the remaining batches, pad more batches up to
        # the number of devices; the padded samples will be removed from the
        # prediction outputs.
        if len(dev_batches) > 0:
            num_pad = dev_count - len(dev_batches)
            for i in range(len(dev_batches), dev_count):
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)

    queue = Queue.Queue(dev_count * prefetch_steps)
    p = Thread(target=worker, args=(reader, dev_count, queue))
    p.daemon = True
    p.start()
    while True:
        ret = queue.get()
        queue.task_done()
        if ret is not None:
            batches, num_pad = ret
            batch_buf = []
            flag_buf = []
            for idx, batch in enumerate(batches):
                # mark the last num_pad batches as padding
                flag = idx - len(batches) < -num_pad
                batch = postprocess_fn(batch)
                batch_buf.append(batch)
                flag_buf.append(flag)
            yield batch_buf, flag_buf
        else:
            break
    queue.join()
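
A sketch of how the feeder is typically consumed; the reader below is a hypothetical toy, and dev_count again comes from the module scope (assumed 2 here):

# Five small batches on two devices: the final pair is padded with a
# duplicate batch, and its flag comes back False.
def toy_reader():
    for i in range(5):
        yield {'x': [i]}

for batches, flags in data_feeder(toy_reader, prefetch_steps=2):
    for batch, is_real in zip(batches, flags):
        if is_real:
            print(batch)  # padded duplicates are skipped here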
@@ -31,12 +31,11 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
from paddlepalm.utils.config_helper import PDConfig
from paddlepalm.utils.print_helper import print_dict
from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
from paddlepalm.distribute import data_feeder

from default_settings import *
from task_instance import TaskInstance, check_instances
import Queue
from threading import Thread
DEBUG=False
VERBOSE=0
@@ -185,6 +184,20 @@ def _fit_attr(conf, fit_attr, strict=False):
    return conf
def create_feed_batch_process_fn(net_inputs):

    def feed_batch_process_fn(data):
        temp = {}
        for q, var in net_inputs.items():
            if isinstance(var, str) or isinstance(var, unicode):
                temp[var] = data[q]
            else:
                temp[var.name] = data[q]
        return temp

    return feed_batch_process_fn
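
A minimal sketch of the returned closure, with hypothetical names; values in net_inputs may be plain feed names (str/unicode) or fluid Variables carrying a .name attribute:

# Hypothetical mapping from reader keys to feed-variable names.
net_inputs = {'token_ids': 'token_ids_var', 'label_ids': 'label_ids_var'}
process = create_feed_batch_process_fn(net_inputs)
print(process({'token_ids': [[1, 2, 3]], 'label_ids': [[0]]}))
# -> {'token_ids_var': [[1, 2, 3]], 'label_ids_var': [[0]]}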
class Controller(object):

    def __init__(self, config, task_dir='.', for_train=True):
@@ -524,6 +537,7 @@ class Controller(object):
                insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)

            pred_prog = inst.load(infer_model_path)
            pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
            if inst.reader['pred'] is None:
                pred_reader = inst.Reader(inst.config, phase='pred')
                inst.reader['pred'] = pred_reader
@@ -574,18 +588,6 @@ class Controller(object):
                return False
            return True
        def pack_multicard_feed(iterator, net_inputs, dev_count):
            ret = []
            mask = []
            for i in range(dev_count):
                temp = {}
                content, flag = next(iterator)
                for q, var in net_inputs.items():
                    temp[var.name] = content[q]
                ret.append(temp)
                mask.append(1 if flag else 0)
            return ret, mask
        # do training
        fetch_names, fetch_list = zip(*fetches.items())
@@ -594,50 +596,18 @@ class Controller(object):
        epoch = 0
        time_begin = time.time()
        backbone_buffer = []
        def multi_dev_reader(reader, dev_count):
            def worker(reader, dev_count, queue):
                dev_batches = []
                for index, data in enumerate(reader()):
                    if len(dev_batches) < dev_count:
                        dev_batches.append(data)
                    if len(dev_batches) == dev_count:
                        queue.put((dev_batches, 0))
                        dev_batches = []
                # For the prediction of the remaining batches, pad more batches up to
                # the number of devices; the padded samples will be removed from the
                # prediction outputs.
                if len(dev_batches) > 0:
                    num_pad = dev_count - len(dev_batches)
                    for i in range(len(dev_batches), dev_count):
                        dev_batches.append(dev_batches[-1])
                    queue.put((dev_batches, num_pad))
                queue.put(None)

            queue = Queue.Queue(dev_count * 2)
            p = Thread(target=worker, args=(reader, dev_count, queue))
            p.daemon = True
            p.start()
            while True:
                ret = queue.get()
                if ret is not None:
                    batches, num_pad = ret
                    queue.task_done()
                    for batch in batches:
                        flag = num_pad == 0
                        if num_pad > 0:
                            num_pad -= 1
                        yield batch, flag
                else:
                    break
            queue.join()
        feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
        distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)
        # palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2)
        while not train_finish():
            feed, mask = next(distribute_feeder)
            rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)

            while mask.pop() == False:
                rt_outputs.pop()

            rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
            rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist()
            rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id
@@ -714,19 +684,38 @@ class Controller(object):
        fetch_names, fetch_vars = inst.pred_fetch_list

        print('predicting...')
        feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
        distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1)

        buf = []
        for feed, mask in distribute_feeder:
            rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)

            # split each fetched tensor into per-device pieces, then drop the
            # trailing pieces that mask marks as padding
            splited_rt_outputs = []
            for item in rt_outputs:
                splited_rt_outputs.append(np.split(item, len(mask)))
            while mask.pop() == False:
                for item in splited_rt_outputs:
                    item.pop()

            # concatenate the remaining pieces back into whole tensors
            rt_outputs = []
            for item in splited_rt_outputs:
                rt_outputs.append(np.concatenate(item))

            rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
            inst.postprocess(rt_outputs, phase='pred')

        if inst.task_layer['pred'].epoch_inputs_attrs:
            reader_outputs = inst.reader['pred'].get_epoch_outputs()
        else:
            reader_outputs = None

        inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
...
from paddle import fluid
import os
import multiprocessing
gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count
@@ -16,6 +16,18 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
def match(learning_strategy='pointwise', siamese=False):
    if siamese:
        return SiameseMatchReader(..., learning_strategy)
    else:
        return ClassifyReader(..., learning_strategy)
class Reader(reader):

    def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
@@ -67,6 +79,28 @@ class Reader(reader):
            "label_ids": [[-1,1], 'int64'],
            "task_ids": [[-1, -1, 1], 'int64']
            }
        if siamese:
            if learning_strategy == 'pointwise':
                return {
                    'token_ids_A': ...,
                    'token_ids_B': ...,
                    "position_ids_A": [[-1, -1, 1], 'int64'],
                    "position_ids_B": [[-1, -1, 1], 'int64'],
                    }
            elif ...:
                return {
                    'token_ids_A': ...,
                    'token_ids_B': ...,
                    'token_ids_A_neg': ...,
                    }
        else:
            if learning_strategy == 'pairwise':
                return {
                    "token_ids": ...,
                    "token_ids_neg": ...,
                    ...
                    }
            else:
                return {"token_ids": [[-1, -1, 1], 'int64'],
                        "position_ids": [[-1, -1, 1], 'int64'],
...
@@ -83,8 +83,6 @@ class Reader(reader):
            return outputs

        for batch in self._data_generator():
# print(np.shape(list_to_dict(batch)['token_ids']))
# print(list_to_dict(batch)['mask_label'].tolist())
            yield list_to_dict(batch)

    def get_epoch_outputs(self):
...
@@ -15,6 +15,7 @@
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
import numpy as np
class Reader(reader):
@@ -73,6 +74,7 @@ class Reader(reader):
            "segment_ids": [[-1, -1, 1], 'int64'],
            "input_mask": [[-1, -1, 1], 'float32'],
            "start_positions": [[-1, 1], 'int64'],
            "unique_ids": [[-1, 1], 'int64'],
            "end_positions": [[-1, 1], 'int64'],
            "task_ids": [[-1, -1, 1], 'int64']
            }
@@ -108,6 +110,7 @@ class Reader(reader):
            return outputs

        for batch in self._data_generator():
            yield list_to_dict(batch)

    def get_epoch_outputs(self):
...
@@ -19,57 +19,76 @@ from __future__ import print_function
import numpy as np

def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
    """
    Add masks to batch_tokens; return out, mask_label, mask_pos.
    Note: mask_pos corresponds to batch_tokens after padding.
    """
    max_len = max([len(sent) for sent in batch_tokens])

    multidev_batch_tokens = []
    multidev_mask_label = []
    multidev_mask_pos = []

    # split the big batch into dev_count strides and mask each stride separately
    big_batch_tokens = batch_tokens
    stride = len(batch_tokens) // dev_count
    if stride == 0:
        return None, None, None
    p = stride

    for i in range(dev_count):
        batch_tokens = big_batch_tokens[p-stride:p]
        p += stride
        mask_label = []
        mask_pos = []
        prob_mask = np.random.rand(total_token_num)
        # Note: the first token is [CLS], so [low=1]
        replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
        pre_sent_len = 0
        prob_index = 0
        for sent_index, sent in enumerate(batch_tokens):
            mask_flag = False
            prob_index += pre_sent_len
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
            pre_sent_len = len(sent)
            # ensure at least one word is masked in each sentence
            while not mask_flag:
                token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
                if sent[token_index] != SEP and sent[token_index] != CLS:
                    mask_label.append(sent[token_index])
                    sent[token_index] = MASK
                    mask_flag = True
                    mask_pos.append(sent_index * max_len + token_index)
        mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
        mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])

        multidev_batch_tokens.extend(batch_tokens)
        multidev_mask_label.append(mask_label)
        multidev_mask_pos.append(mask_pos)

    return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
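
A toy call sketching the per-device outputs, assuming dev_count=2 and the default CLS/SEP ids; the token ids below are made up:

# Four toy sentences on two devices -> two sentences per device piece.
batch_tokens = [[1, 10, 11, 2], [1, 12, 13, 2], [1, 14, 15, 2], [1, 16, 17, 2]]
total_token_num = sum(len(sent) for sent in batch_tokens)
out, mask_labels, mask_positions = mask(
    batch_tokens, total_token_num, vocab_size=1000, dev_count=2)
# out: all (possibly masked) sentences, concatenated across devices;
# mask_labels / mask_positions: one [-1, 1] int64 array per device.
print(len(out), len(mask_labels), len(mask_positions))  # 4 2 2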
def prepare_batch_data(insts,
@@ -83,7 +102,8 @@ def prepare_batch_data(insts,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False,
                       dev_count=1):
""" """
1. generate Tensor of data 1. generate Tensor of data
2. generate Tensor of position 2. generate Tensor of position
@@ -101,7 +121,8 @@ def prepare_batch_data(insts,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id,
        dev_count=dev_count)
    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out,
@@ -125,7 +146,7 @@ def prepare_batch_data(insts,
    return_list = [
        src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
    ]

    return return_list
def pad_batch_data(insts,
...
@@ -29,11 +29,14 @@ import six
from io import open
from collections import namedtuple
from . import gpu_dev_count
import paddlepalm as palm
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
from paddlepalm.reader.utils.batching4ernie import pad_batch_data
from paddlepalm.reader.utils.mlm_batching import prepare_batch_data

log = logging.getLogger(__name__)

if six.PY3:
@@ -435,14 +438,12 @@ class MaskLMReader(BaseReader):
                # max_len=self.max_seq_len,  # Note: if we pad to the maximum length, mask_pos no longer matches the actual positions, because mask_pos is computed from the longest sequence within the batch.
                return_input_mask=True,
                return_max_len=False,
                return_num_token=False,
                dev_count=gpu_dev_count)

            # yield batch
            for piece in palm.distribute.yield_pieces(batch_data, ['s', 's', 's', 's', 's', 'u', 'u'], batch_size):
                yield piece

        return wrapper
@@ -890,11 +891,20 @@ class MRCReader(BaseReader):
            if to_append:
                batch_records.append(record)
            else:
                # yield self._pad_batch_records(batch_records, phase == "train")
                ds = ['s'] * 8
                for piece in palm.distribute.yield_pieces(
                        self._pad_batch_records(batch_records, phase == 'train'),
                        ds, batch_size):
                    yield piece
                batch_records, max_len = [record], len(record.token_ids)

        if phase == 'pred' and batch_records:
            ds = ['s'] * 8
            for piece in palm.distribute.yield_pieces(
                    self._pad_batch_records(batch_records, phase == 'train'),
                    ds, batch_size):
                yield piece

    def _pad_batch_records(self, batch_records, is_training):
        batch_token_ids = [record.token_ids for record in batch_records]
@@ -981,12 +991,8 @@ class MRCReader(BaseReader):
            for batch_data in self._prepare_batch_data(
                    features, batch_size, phase=phase):
                yield batch_data

        return wrapper
...
@@ -169,7 +169,7 @@ class TaskInstance(object):
    @property
    def pred_input(self):
        return dict(zip(*[self._pred_input_name_list, self._pred_input_varname_list]))

    @pred_input.setter
    def pred_input(self, val):
...
@@ -59,8 +59,13 @@ class TaskParadigm(task_paradigm):
    def build(self, inputs, scope_name=""):
        if self._is_training:
            if learning_strategy == 'pointwise':
                labels = inputs["reader"]["label_ids"]
            elif learning_strategy == 'pairwise':
                inputs['backbone']["sentence_embedding_neg"]
        cls_feats = inputs["backbone"]["sentence_pair_embedding"]
        cls_feats = inputs["backbone"]["sentence_pair_embedding_neg"]
        if self._is_training:
            cls_feats = fluid.layers.dropout(
...
@@ -82,9 +82,11 @@ class TaskParadigm(task_paradigm):
            end_positions = fluid.layers.elementwise_min(end_positions, max_position)
            start_positions.stop_gradient = True
            end_positions.stop_gradient = True
        else:
            unique_id = inputs['reader']['unique_ids']

        enc_out = inputs['backbone']['encoder_outputs']
        logits = fluid.layers.fc(
            input=enc_out,
...
@@ -204,13 +204,13 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
                print(np.shape(i))
            print('')
            v -= 1
        if return_type == 'list':
            yield results
        elif return_type == 'dict':
            temp = {}
            for pos, i in enumerate(results):
                temp[pos_to_outname[pos]] = i
            yield temp

    return iterator
...