提交 c2ec3397 编写于 作者: T typhoonzero

update

上级 50b190db
...@@ -66,7 +66,6 @@ def parse_args(): ...@@ -66,7 +66,6 @@ def parse_args():
add_arg('split_var', bool, True, "Split params on pserver.") add_arg('split_var', bool, True, "Split params on pserver.")
add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.") add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.")
add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.") add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
add_arg('skip_unbalanced_data', bool, False, "Skip data not if data not balanced on nodes.")
add_arg('enable_sequential_execution', bool, False, "Skip data not if data not balanced on nodes.") add_arg('enable_sequential_execution', bool, False, "Skip data not if data not balanced on nodes.")
#for dgc #for dgc
add_arg('enable_dgc', bool, False, "Skip data not if data not balanced on nodes.") add_arg('enable_dgc', bool, False, "Skip data not if data not balanced on nodes.")
...@@ -85,13 +84,11 @@ def get_device_num(): ...@@ -85,13 +84,11 @@ def get_device_num():
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n') device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
return device_num return device_num
def prepare_reader(is_train, pyreader, args, pass_id=0): def prepare_reader(is_train, pyreader, args, pass_id=1):
# NOTE: allways set reader infinite when nccl2 mode to balance data # NOTE: always use infinite reader for dist training
# between ranks
is_infinite = (args.update_method == "nccl2")
if is_train: if is_train:
reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id, reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id,
infinite=is_infinite) infinite=True)
else: else:
reader = val(data_dir=args.data_dir) reader = val(data_dir=args.data_dir)
if is_train: if is_train:
...@@ -335,8 +332,7 @@ def train_parallel(args): ...@@ -335,8 +332,7 @@ def train_parallel(args):
num_samples = 0 num_samples = 0
start_time = time.time() start_time = time.time()
batch_id = 1 batch_id = 1
# use pass_id+1 as per pass global shuffle for distributed training if pass_id == 0:
prepare_reader(True, train_pyreader, args, pass_id + 1)
train_pyreader.start() train_pyreader.start()
while True: while True:
try: try:
...@@ -355,11 +351,10 @@ def train_parallel(args): ...@@ -355,11 +351,10 @@ def train_parallel(args):
break break
num_samples += args.batch_size num_samples += args.batch_size
batch_id += 1 batch_id += 1
if (args.skip_unbalanced_data or args.update_method == "nccl2") and batch_id >= steps_per_pass: if batch_id >= steps_per_pass:
break break
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
train_pyreader.reset()
if pass_id >= args.start_test_pass: if pass_id >= args.start_test_pass:
if args.multi_batch_repeat > 1: if args.multi_batch_repeat > 1:
copyback_repeat_bn_params(train_prog) copyback_repeat_bn_params(train_prog)
...@@ -375,6 +370,7 @@ def train_parallel(args): ...@@ -375,6 +370,7 @@ def train_parallel(args):
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog) fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog)
train_pyreader.reset()
startup_exe.close() startup_exe.close()
print("total train time: ", time.time() - over_all_start) print("total train time: ", time.time() - over_all_start)
......
...@@ -15,5 +15,7 @@ PADDLE_TRAINING_ROLE="TRAINER" \ ...@@ -15,5 +15,7 @@ PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \ PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \
PADDLE_TRAINER_ID="${i}" \ PADDLE_TRAINER_ID="${i}" \
FLAGS_selected_gpus="${i}" \ FLAGS_selected_gpus="${i}" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 --fp16 1 --scale_loss 8 &> logs/tr$i.log & python -u dist_train.py --model $MODEL --update_method nccl2 \
--batch_size 32 \
--fp16 0 --scale_loss 1 &> logs/tr$i.log &
done done
...@@ -12,7 +12,7 @@ np.random.seed(0) ...@@ -12,7 +12,7 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 102400 BUF_SIZE = 1024
DATA_DIR = 'data/ILSVRC2012' DATA_DIR = 'data/ILSVRC2012'
...@@ -131,7 +131,7 @@ def _reader_creator(file_list, ...@@ -131,7 +131,7 @@ def _reader_creator(file_list,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=DATA_DIR, data_dir=DATA_DIR,
pass_id_as_seed=0, pass_id_as_seed=1,
infinite=False): infinite=False):
def reader(): def reader():
with open(file_list) as flist: with open(file_list) as flist:
...@@ -176,7 +176,7 @@ def _reader_creator(file_list, ...@@ -176,7 +176,7 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR, pass_id_as_seed=0, infinite=False): def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
file_list = os.path.join(data_dir, 'train_list.txt') file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator( return _reader_creator(
file_list, file_list,
...@@ -185,7 +185,8 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=0, infinite=False): ...@@ -185,7 +185,8 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=0, infinite=False):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=data_dir, data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed) pass_id_as_seed=pass_id_as_seed,
infinite=infinite)
def val(data_dir=DATA_DIR): def val(data_dir=DATA_DIR):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册