Unverified commit bf298439, authored by Taylor Robie, committed by GitHub

fix error when last shard is not assigned a batch (#5569)

Parent 19d4eaaf
@@ -206,7 +206,9 @@ def _construct_records(
num_workers: Number of multiprocessing workers to use for negative
generation.
cache_paths: Paths object with information of where to write files.
-num_readers: The number of reader datasets in the input_fn.
+num_readers: The number of reader datasets in the input_fn. This number is
+approximate; fewer shards will be created if not all shards are assigned
+batches. This can occur due to discretization in the assignment process.
num_neg: The number of false negatives per positive example.
num_positives: The number of positive examples. This value is used
to pre-allocate arrays while the imap is still running. (NumPy does not
@@ -307,6 +309,10 @@ def _construct_records(
break
batches_by_file[current_file_id].append(current_batch_id)
+# Drop shards which were not assigned batches
+batches_by_file = [i for i in batches_by_file if i]
+num_readers = len(batches_by_file)
if is_training:
# Empirically it is observed that placing the batch with repeated values at
# the start rather than the end improves convergence.
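The snippet below is a minimal sketch of why a shard can end up without batches and what the added filter does. The assignment rule shown here (ceil division of batches across shards) and the helper names `assign_batches` and `drop_empty_shards` are assumptions made for illustration; they are not taken from `_construct_records`, whose actual assignment loop is not visible in this diff.

```python
import math


def assign_batches(num_batches, num_readers):
    """Hypothetical assignment: give each shard up to ceil(n / r) batches."""
    batches_per_file = math.ceil(num_batches / num_readers)
    batches_by_file = [[] for _ in range(num_readers)]
    for batch_id in range(num_batches):
        # Integer division rounds the shard index down, so trailing shards
        # can be left empty when num_batches does not divide evenly.
        batches_by_file[batch_id // batches_per_file].append(batch_id)
    return batches_by_file


def drop_empty_shards(batches_by_file):
    """Same filter as the commit: keep only shards that received batches."""
    kept = [i for i in batches_by_file if i]
    return kept, len(kept)


if __name__ == "__main__":
    shards = assign_batches(num_batches=9, num_readers=4)
    print(shards)  # the last of the four shards is empty
    shards, num_readers = drop_empty_shards(shards)
    print(shards, num_readers)
```

With 9 batches and 4 requested readers, this rule fills three shards with three batches each and leaves the fourth empty, so the effective reader count after filtering is 3. This is the kind of discretization effect the updated docstring describes.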