未验证 提交 553659a5 编写于 作者: C chengduo 提交者: GitHub

yolov3 is a special model: if num_trainers > 1, each process trains the complete dataset (#2605)

上级 931b0135
...@@ -272,13 +272,15 @@ class DataSetReader(object): ...@@ -272,13 +272,15 @@ class DataSetReader(object):
batch_out = [(im, im_id, im_shape)] batch_out = [(im, im_id, im_shape)]
yield batch_out yield batch_out
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) # NOTE: yolov3 is a special model, if num_trainers > 1, each process
if mode == 'train' and num_trainers > 1: # trains the complete dataset.
assert shuffle_seed is not None, \ # num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
"If num_trainers > 1, the shuffle_seed must be set, because " \ # if mode == 'train' and num_trainers > 1:
"the order of batch data generated by reader " \ # assert shuffle_seed is not None, \
"must be the same in the respective processes." # "If num_trainers > 1, the shuffle_seed must be set, because " \
reader = fluid.contrib.reader.distributed_batch_reader(reader) # "the order of batch data generated by reader " \
# "must be the same in the respective processes."
# reader = fluid.contrib.reader.distributed_batch_reader(reader)
return reader return reader
......
...@@ -47,7 +47,6 @@ import dist_utils ...@@ -47,7 +47,6 @@ import dist_utils
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
def get_device_num(): def get_device_num():
# NOTE(zcd): for multi-process training, each process uses one GPU card. # NOTE(zcd): for multi-process training, each process uses one GPU card.
if num_trainers > 1: return 1 if num_trainers > 1: return 1
...@@ -141,10 +140,10 @@ def train(): ...@@ -141,10 +140,10 @@ def train():
shuffle = True shuffle = True
if args.enable_ce: if args.enable_ce:
shuffle = False shuffle = False
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because shuffle_seed = None
# the order of batch data generated by reader # NOTE: yolov3 is a special model, if num_trainers > 1, each process
# must be the same in the respective processes. # trains the complete dataset.
shuffle_seed = 1 if num_trainers > 1 else None # if num_trainers > 1: shuffle_seed = 1
train_reader = reader.train( train_reader = reader.train(
input_size, input_size,
batch_size=cfg.batch_size, batch_size=cfg.batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册