未验证 提交 3f0591a5 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #63 from wangxiao1021/api

fix bugs
......@@ -116,6 +116,7 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train', is
def decode_fake(nums, mask, bs):
bs //= dev_count
n_t = 0
for flag in mask:
if not flag:
......
......@@ -293,11 +293,18 @@ class Reader(object):
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records)
ds = ['s'] * 7
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records),
ds, batch_size):
yield piece
batch_records, max_len = [record], len(record.token_ids)
if phase == 'predict' and batch_records:
yield self._pad_batch_records(batch_records)
for piece in palm.distribute.yield_pieces(\
self._pad_batch_records(batch_records),
ds, batch_size):
yield piece
def get_num_examples(self, input_file=None, phase='train'):
if input_file is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册