未验证 提交 553659a5 编写于 作者: C chengduo 提交者: GitHub

yolov3 is a special model, if num_trainers > 1, each process trian the completed dataset (#2605)

上级 931b0135
......@@ -272,13 +272,15 @@ class DataSetReader(object):
batch_out = [(im, im_id, im_shape)]
yield batch_out
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if mode == 'train' and num_trainers > 1:
assert shuffle_seed is not None, \
"If num_trainers > 1, the shuffle_seed must be set, because " \
"the order of batch data generated by reader " \
"must be the same in the respective processes."
reader = fluid.contrib.reader.distributed_batch_reader(reader)
# NOTE: yolov3 is a special model, if num_trainers > 1, each process
# trian the completed dataset.
# num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
# if mode == 'train' and num_trainers > 1:
# assert shuffle_seed is not None, \
# "If num_trainers > 1, the shuffle_seed must be set, because " \
# "the order of batch data generated by reader " \
# "must be the same in the respective processes."
# reader = fluid.contrib.reader.distributed_batch_reader(reader)
return reader
......
......@@ -47,7 +47,6 @@ import dist_utils
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
def get_device_num():
# NOTE(zcd): for multi-processe training, each process use one GPU card.
if num_trainers > 1: return 1
......@@ -141,10 +140,10 @@ def train():
shuffle = True
if args.enable_ce:
shuffle = False
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because
# the order of batch data generated by reader
# must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
shuffle_seed = None
# NOTE: yolov3 is a special model, if num_trainers > 1, each process
# trian the completed dataset.
# if num_trainers > 1: shuffle_seed = 1
train_reader = reader.train(
input_size,
batch_size=cfg.batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册