提交 ad7d3792 编写于 作者: W wangxiao1021

fix bugs

上级 d87bd156
...@@ -116,6 +116,7 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is ...@@ -116,6 +116,7 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is
def decode_fake(nums, mask, bs): def decode_fake(nums, mask, bs):
bs //= dev_count
n_t = 0 n_t = 0
for flag in mask: for flag in mask:
if not flag: if not flag:
......
...@@ -293,11 +293,18 @@ class Reader(object): ...@@ -293,11 +293,18 @@ class Reader(object):
if to_append: if to_append:
batch_records.append(record) batch_records.append(record)
else: else:
yield self._pad_batch_records(batch_records) ds = ['s'] * 7
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records),
ds, batch_size):
yield piece
batch_records, max_len = [record], len(record.token_ids) batch_records, max_len = [record], len(record.token_ids)
if phase == 'predict' and batch_records: if phase == 'predict' and batch_records:
yield self._pad_batch_records(batch_records) for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records),
ds, batch_size):
yield piece
def get_num_examples(self, input_file=None, phase='train'): def get_num_examples(self, input_file=None, phase='train'):
if input_file is None: if input_file is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册